Mirror of https://github.com/Monadical-SAS/reflector.git, synced 2025-12-20 20:29:06 +00:00

Compare commits: v0.7.2 ... mathieu/ca (5 commits)

- 311d453e41
- f286f0882c
- ffcafb3bf2
- 27075d840c
- 30b5cd45e3
77 .github/workflows/deploy.yml vendored

@@ -8,30 +8,18 @@ env:
  ECR_REPOSITORY: reflector

jobs:
  build:
    strategy:
      matrix:
        include:
          - platform: linux/amd64
            runner: linux-amd64
            arch: amd64
          - platform: linux/arm64
            runner: linux-arm64
            arch: arm64

    runs-on: ${{ matrix.runner }}
  deploy:
    runs-on: ubuntu-latest

    permissions:
      deployments: write
      contents: read

    outputs:
      registry: ${{ steps.login-ecr.outputs.registry }}

    steps:
      - uses: actions/checkout@v4
      - uses: actions/checkout@v3

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v4
        uses: aws-actions/configure-aws-credentials@0e613a0980cbf65ed5b322eb7a1e075d28913a83
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
@@ -39,52 +27,21 @@ jobs:

      - name: Login to Amazon ECR
        id: login-ecr
        uses: aws-actions/amazon-ecr-login@v2
        uses: aws-actions/amazon-ecr-login@62f4f872db3836360b72999f4b87f1ff13310f3a

      - name: Set up QEMU
        uses: docker/setup-qemu-action@v2

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
        uses: docker/setup-buildx-action@v2

      - name: Build and push ${{ matrix.arch }}
        uses: docker/build-push-action@v5
      - name: Build and push
        id: docker_build
        uses: docker/build-push-action@v4
        with:
          context: server
          platforms: ${{ matrix.platform }}
          platforms: linux/amd64,linux/arm64
          push: true
          tags: ${{ steps.login-ecr.outputs.registry }}/${{ env.ECR_REPOSITORY }}:latest-${{ matrix.arch }}
          cache-from: type=gha,scope=${{ matrix.arch }}
          cache-to: type=gha,mode=max,scope=${{ matrix.arch }}
          github-token: ${{ secrets.GHA_CACHE_TOKEN }}
          provenance: false

  create-manifest:
    runs-on: ubuntu-latest
    needs: [build]

    permissions:
      deployments: write
      contents: read

    steps:
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v4
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: ${{ env.AWS_REGION }}

      - name: Login to Amazon ECR
        uses: aws-actions/amazon-ecr-login@v2

      - name: Create and push multi-arch manifest
        run: |
          # Get the registry URL (since we can't easily access job outputs in matrix)
          ECR_REGISTRY=$(aws ecr describe-registry --query 'registryId' --output text).dkr.ecr.${{ env.AWS_REGION }}.amazonaws.com

          docker manifest create \
            $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest \
            $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest-amd64 \
            $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest-arm64

          docker manifest push $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest

          echo "✅ Multi-arch manifest pushed: $ECR_REGISTRY/${{ env.ECR_REPOSITORY }}:latest"
          tags: ${{ steps.login-ecr.outputs.registry }}/${{ env.ECR_REPOSITORY }}:latest
          cache-from: type=gha
          cache-to: type=gha,mode=max
38 .github/workflows/test_server.yml vendored

@@ -19,41 +19,29 @@ jobs:
    steps:
      - uses: actions/checkout@v4
      - name: Install uv
        uses: astral-sh/setup-uv@v6
        uses: astral-sh/setup-uv@v3
        with:
          enable-cache: true
          working-directory: server

      - name: Tests
        run: |
          cd server
          uv run -m pytest -v tests

  docker-amd64:
    runs-on: linux-amd64
  docker:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up QEMU
        uses: docker/setup-qemu-action@v2
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Build AMD64
        uses: docker/build-push-action@v6
        uses: docker/setup-buildx-action@v2
      - name: Build and push
        id: docker_build
        uses: docker/build-push-action@v4
        with:
          context: server
          platforms: linux/amd64
          cache-from: type=gha,scope=amd64
          cache-to: type=gha,mode=max,scope=amd64
          github-token: ${{ secrets.GHA_CACHE_TOKEN }}

  docker-arm64:
    runs-on: linux-arm64
    steps:
      - uses: actions/checkout@v4
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Build ARM64
        uses: docker/build-push-action@v6
        with:
          context: server
          platforms: linux/arm64
          cache-from: type=gha,scope=arm64
          cache-to: type=gha,mode=max,scope=arm64
          github-token: ${{ secrets.GHA_CACHE_TOKEN }}
          platforms: linux/amd64,linux/arm64
          cache-from: type=gha
          cache-to: type=gha,mode=max
32 CHANGELOG.md

@@ -1,37 +1,5 @@
# Changelog

## [0.7.2](https://github.com/Monadical-SAS/reflector/compare/v0.7.1...v0.7.2) (2025-08-21)

### Bug Fixes

* docker image not loading libgomp.so.1 for torch ([#560](https://github.com/Monadical-SAS/reflector/issues/560)) ([773fccd](https://github.com/Monadical-SAS/reflector/commit/773fccd93e887c3493abc2e4a4864dddce610177))
* include shared rooms to search ([#558](https://github.com/Monadical-SAS/reflector/issues/558)) ([499eced](https://github.com/Monadical-SAS/reflector/commit/499eced3360b84fb3a90e1c8a3b554290d21adc2))

## [0.7.1](https://github.com/Monadical-SAS/reflector/compare/v0.7.0...v0.7.1) (2025-08-21)

### Bug Fixes

* webvtt db null expectation mismatch ([#556](https://github.com/Monadical-SAS/reflector/issues/556)) ([e67ad1a](https://github.com/Monadical-SAS/reflector/commit/e67ad1a4a2054467bfeb1e0258fbac5868aaaf21))

## [0.7.0](https://github.com/Monadical-SAS/reflector/compare/v0.6.1...v0.7.0) (2025-08-21)

### Features

* delete recording with transcript ([#547](https://github.com/Monadical-SAS/reflector/issues/547)) ([99cc984](https://github.com/Monadical-SAS/reflector/commit/99cc9840b3f5de01e0adfbfae93234042d706d13))
* pipeline improvement with file processing, parakeet, silero-vad ([#540](https://github.com/Monadical-SAS/reflector/issues/540)) ([bcc29c9](https://github.com/Monadical-SAS/reflector/commit/bcc29c9e0050ae215f89d460e9d645aaf6a5e486))
* postgresql migration and removal of sqlite in pytest ([#546](https://github.com/Monadical-SAS/reflector/issues/546)) ([cd1990f](https://github.com/Monadical-SAS/reflector/commit/cd1990f8f0fe1503ef5069512f33777a73a93d7f))
* search backend ([#537](https://github.com/Monadical-SAS/reflector/issues/537)) ([5f9b892](https://github.com/Monadical-SAS/reflector/commit/5f9b89260c9ef7f3c921319719467df22830453f))
* search frontend ([#551](https://github.com/Monadical-SAS/reflector/issues/551)) ([3657242](https://github.com/Monadical-SAS/reflector/commit/365724271ca6e615e3425125a69ae2b46ce39285))

### Bug Fixes

* evaluation cli event wrap ([#536](https://github.com/Monadical-SAS/reflector/issues/536)) ([941c3db](https://github.com/Monadical-SAS/reflector/commit/941c3db0bdacc7b61fea412f3746cc5a7cb67836))
* use structlog not logging ([#550](https://github.com/Monadical-SAS/reflector/issues/550)) ([27e2f81](https://github.com/Monadical-SAS/reflector/commit/27e2f81fda5232e53edc729d3e99c5ef03adbfe9))

## [0.6.1](https://github.com/Monadical-SAS/reflector/compare/v0.6.0...v0.6.1) (2025-08-06)
497 ICS_IMPLEMENTATION.md Normal file

@@ -0,0 +1,497 @@

# ICS Calendar Integration - Implementation Guide

## Overview

This document provides detailed implementation guidance for integrating ICS calendar feeds with Reflector rooms. Unlike CalDAV, which requires complex authentication and protocol handling, ICS integration uses simple HTTP(S) fetching of calendar files.

## Key Differences from CalDAV Approach

| Aspect | CalDAV | ICS |
|--------|--------|-----|
| Protocol | WebDAV extension | HTTP/HTTPS GET |
| Authentication | Username/password, OAuth | Tokens embedded in URL |
| Data Access | Selective event queries | Full calendar download |
| Implementation | Complex (caldav library) | Simple (requests + icalendar) |
| Real-time Updates | Supported | Polling only |
| Write Access | Yes | No (read-only) |

## Technical Architecture

### 1. ICS Fetching Service

```python
# reflector/services/ics_sync.py

import requests
from icalendar import Calendar
from typing import List


class ICSFetchService:
    def __init__(self):
        self.session = requests.Session()
        self.session.headers.update({'User-Agent': 'Reflector/1.0'})

    def fetch_ics(self, url: str) -> str:
        """Fetch ICS file from URL (authentication via URL token if needed)."""
        response = self.session.get(url, timeout=30)
        response.raise_for_status()
        return response.text

    def parse_ics(self, ics_content: str) -> Calendar:
        """Parse ICS content into a calendar object."""
        return Calendar.from_ical(ics_content)

    def extract_room_events(self, calendar: Calendar, room_url: str) -> List[dict]:
        """Extract events that match the room URL."""
        events = []

        for component in calendar.walk():
            if component.name == "VEVENT":
                # Check if event matches this room
                if self._event_matches_room(component, room_url):
                    events.append(self._parse_event(component))

        return events

    def _event_matches_room(self, event, room_url: str) -> bool:
        """Check if event location or description contains the room URL."""
        location = str(event.get('LOCATION', ''))
        description = str(event.get('DESCRIPTION', ''))

        # Support various URL formats
        patterns = [
            room_url,
            room_url.replace('https://', ''),
            room_url.split('/')[-1],  # Just room name
        ]

        for pattern in patterns:
            if pattern in location or pattern in description:
                return True

        return False
```
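
A quick usage sketch of the service; the feed URL and room URL below are placeholders, and `_parse_event` is outlined under Event Field Mapping later in this guide:

```python
# Hypothetical feed and room URLs, for illustration only
service = ICSFetchService()
ics_text = service.fetch_ics("https://example.com/calendar.ics?token=abc")
calendar = service.parse_ics(ics_text)
events = service.extract_room_events(calendar, "https://reflector.monadical.com/engineering")
print(f"{len(events)} events reference this room")
```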

### 2. Database Schema

```sql
-- Modify room table
ALTER TABLE room ADD COLUMN ics_url TEXT; -- encrypted to protect embedded tokens
ALTER TABLE room ADD COLUMN ics_fetch_interval INTEGER DEFAULT 300; -- seconds
ALTER TABLE room ADD COLUMN ics_enabled BOOLEAN DEFAULT FALSE;
ALTER TABLE room ADD COLUMN ics_last_sync TIMESTAMP;
ALTER TABLE room ADD COLUMN ics_last_etag TEXT; -- for caching

-- Calendar events table
CREATE TABLE calendar_event (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    room_id UUID REFERENCES room(id) ON DELETE CASCADE,
    external_id TEXT NOT NULL, -- ICS UID
    title TEXT,
    description TEXT,
    start_time TIMESTAMP NOT NULL,
    end_time TIMESTAMP NOT NULL,
    attendees JSONB,
    location TEXT,
    ics_raw_data TEXT, -- Store raw VEVENT for reference
    last_synced TIMESTAMP DEFAULT NOW(),
    is_deleted BOOLEAN DEFAULT FALSE,
    created_at TIMESTAMP DEFAULT NOW(),
    updated_at TIMESTAMP DEFAULT NOW(),
    UNIQUE(room_id, external_id)
);

-- Indexes for efficient queries
CREATE INDEX idx_calendar_event_room_start ON calendar_event(room_id, start_time);
CREATE INDEX idx_calendar_event_deleted ON calendar_event(is_deleted) WHERE NOT is_deleted;
```

### 3. Background Tasks

```python
# reflector/worker/tasks/ics_sync.py

import hashlib
from datetime import datetime

from celery import shared_task


@shared_task
def sync_ics_calendars():
    """Sync all enabled ICS calendars based on their fetch intervals."""
    rooms = Room.query.filter_by(ics_enabled=True).all()

    for room in rooms:
        # Check if it's time to sync based on the fetch interval
        if should_sync(room):
            sync_room_calendar.delay(room.id)


@shared_task
def sync_room_calendar(room_id: str):
    """Sync calendar for a specific room."""
    room = Room.query.get(room_id)
    if not room or not room.ics_enabled:
        return

    try:
        # Fetch ICS file (decrypt URL first)
        service = ICSFetchService()
        decrypted_url = decrypt_ics_url(room.ics_url)
        ics_content = service.fetch_ics(decrypted_url)

        # Check if content changed (using ETag or hash)
        content_hash = hashlib.md5(ics_content.encode()).hexdigest()
        if room.ics_last_etag == content_hash:
            logger.info(f"No changes in ICS for room {room_id}")
            return

        # Parse and extract events
        calendar = service.parse_ics(ics_content)
        events = service.extract_room_events(calendar, room.url)

        # Update database
        sync_events_to_database(room_id, events)

        # Update sync metadata
        room.ics_last_sync = datetime.utcnow()
        room.ics_last_etag = content_hash
        db.session.commit()

    except Exception as e:
        logger.error(f"Failed to sync ICS for room {room_id}: {e}")


def should_sync(room) -> bool:
    """Check if the room calendar should be synced."""
    if not room.ics_last_sync:
        return True

    time_since_sync = datetime.utcnow() - room.ics_last_sync
    return time_since_sync.total_seconds() >= room.ics_fetch_interval
```
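
`sync_events_to_database` is referenced above but never defined in this guide. A minimal sketch of the upsert-and-soft-delete behavior described here and in the plan's sync strategy; the ORM calls (`CalendarEvent`, `db.session`) mirror the style of the task code above but are assumptions, not Reflector's actual API:

```python
def sync_events_to_database(room_id: str, events: list[dict]):
    """Upsert fetched events; soft-delete future events missing from the feed (sketch)."""
    from datetime import datetime

    seen_ids = set()
    for data in events:
        seen_ids.add(data['external_id'])
        event = CalendarEvent.query.filter_by(
            room_id=room_id, external_id=data['external_id']
        ).first()
        if event is None:
            event = CalendarEvent(room_id=room_id, external_id=data['external_id'])
            db.session.add(event)
        for field in ('title', 'description', 'start_time', 'end_time',
                      'attendees', 'location', 'ics_raw_data'):
            setattr(event, field, data.get(field))
        event.is_deleted = False
        event.last_synced = datetime.utcnow()

    # Soft-delete future events that disappeared from the feed; never touch past events
    stale = CalendarEvent.query.filter(
        CalendarEvent.room_id == room_id,
        CalendarEvent.start_time > datetime.utcnow(),
        ~CalendarEvent.external_id.in_(seen_ids),
    )
    for event in stale:
        event.is_deleted = True

    db.session.commit()
```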

### 4. Celery Beat Schedule

```python
# reflector/worker/celeryconfig.py

beat_schedule = {
    'sync-ics-calendars': {
        'task': 'reflector.worker.tasks.ics_sync.sync_ics_calendars',
        'schedule': 60.0,  # Check every minute which calendars need syncing
    },
    'pre-create-meetings': {
        'task': 'reflector.worker.tasks.ics_sync.pre_create_calendar_meetings',
        'schedule': 60.0,  # Check every minute for upcoming meetings
    },
}
```

## API Endpoints

### Room ICS Configuration

```python
# PATCH /v1/rooms/{room_id}
{
    "ics_url": "https://calendar.google.com/calendar/ical/.../private-token/basic.ics",
    "ics_fetch_interval": 300,  # seconds
    "ics_enabled": true
    # URL will be encrypted in the database to protect embedded tokens
}
```

### Manual Sync Trigger

```python
# POST /v1/rooms/{room_name}/ics/sync
# Response:
{
    "status": "syncing",
    "last_sync": "2024-01-15T10:30:00Z",
    "events_found": 5
}
```

### ICS Status

```python
# GET /v1/rooms/{room_name}/ics/status
# Response:
{
    "enabled": true,
    "last_sync": "2024-01-15T10:30:00Z",
    "next_sync": "2024-01-15T10:35:00Z",
    "fetch_interval": 300,
    "events_count": 12,
    "upcoming_events": 3
}
```

## ICS Parsing Details

### Event Field Mapping

| ICS Field | Database Field | Notes |
|-----------|---------------|-------|
| UID | external_id | Unique identifier |
| SUMMARY | title | Event title |
| DESCRIPTION | description | Full description |
| DTSTART | start_time | Convert to UTC |
| DTEND | end_time | Convert to UTC |
| LOCATION | location | Check for room URL |
| ATTENDEE | attendees | Parse into JSON |
| ORGANIZER | attendees | Add as organizer |
| STATUS | (internal) | Filter cancelled events |
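
The `_parse_event` helper used by `extract_room_events` is not spelled out in this guide; a minimal sketch of the mapping above, assuming the `normalize_datetime` helper defined under Timezone Handling below. Cancelled events (STATUS:CANCELLED) would be filtered before storage, and attendee parsing is omitted here for brevity:

```python
def _parse_event(self, event) -> dict:
    """Map a VEVENT component onto the calendar_event columns (sketch)."""
    return {
        'external_id': str(event.get('UID', '')),
        'title': str(event.get('SUMMARY', '')),
        'description': str(event.get('DESCRIPTION', '')),
        'start_time': normalize_datetime(event.get('DTSTART')),
        'end_time': normalize_datetime(event.get('DTEND')),
        'location': str(event.get('LOCATION', '')),
        'ics_raw_data': event.to_ical().decode(),
    }
```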

### Handling Recurring Events

```python
def expand_recurring_events(event, start_date, end_date):
    """Expand recurring events into individual occurrences."""
    from dateutil.rrule import rrulestr
    from icalendar import vDatetime

    if 'RRULE' not in event:
        return [event]

    # Parse recurrence rule
    rrule_str = event['RRULE'].to_ical().decode()
    dtstart = event['DTSTART'].dt
    duration = event['DTEND'].dt - dtstart if 'DTEND' in event else None

    # Generate occurrences
    rrule = rrulestr(rrule_str, dtstart=dtstart)
    occurrences = []

    for dt in rrule.between(start_date, end_date):
        # Clone event with new dates; copy() is shallow, so replace the date
        # properties rather than mutating the originals shared with `event`
        occurrence = event.copy()
        occurrence['DTSTART'] = vDatetime(dt)
        if duration is not None:
            occurrence['DTEND'] = vDatetime(dt + duration)

        # Unique ID for each occurrence
        occurrence['UID'] = f"{event['UID']}_{dt.isoformat()}"
        occurrences.append(occurrence)

    return occurrences
```

### Timezone Handling

```python
def normalize_datetime(dt):
    """Convert various datetime formats to UTC."""
    import pytz
    from datetime import datetime

    if hasattr(dt, 'dt'):  # icalendar property
        dt = dt.dt

    if isinstance(dt, datetime):
        if dt.tzinfo is None:
            # Treat naive datetimes as UTC
            dt = pytz.UTC.localize(dt)
        else:
            # Convert to UTC
            dt = dt.astimezone(pytz.UTC)

    return dt
```

## Security Considerations

### 1. URL Validation

```python
def validate_ics_url(url: str) -> bool:
    """Validate ICS URL for security."""
    from urllib.parse import urlparse

    parsed = urlparse(url)

    # Must be HTTPS in production
    if not settings.DEBUG and parsed.scheme != 'https':
        return False

    # Prevent local file access
    if parsed.scheme in ('file', 'ftp'):
        return False

    # Prevent internal network access
    if is_internal_ip(parsed.hostname):
        return False

    return True
```
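
`is_internal_ip` is referenced above but not defined here; a minimal sketch using only the standard library. The resolution behavior is an assumption — production code may also want to pin the resolved address for the actual fetch to avoid DNS rebinding:

```python
import ipaddress
import socket


def is_internal_ip(hostname: str | None) -> bool:
    """Return True if the hostname resolves to a private/loopback address (sketch)."""
    if not hostname:
        return True
    try:
        addresses = {info[4][0] for info in socket.getaddrinfo(hostname, None)}
    except socket.gaierror:
        return True  # treat unresolvable hosts as unsafe
    return any(
        ipaddress.ip_address(addr).is_private or ipaddress.ip_address(addr).is_loopback
        for addr in addresses
    )
```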

### 2. Rate Limiting

```python
# Implement per-room rate limiting
RATE_LIMITS = {
    'min_fetch_interval': 60,  # Minimum 1 minute between fetches
    'max_requests_per_hour': 60,  # Max 60 requests per hour per room
    'max_file_size': 10 * 1024 * 1024,  # Max 10MB ICS file
}
```
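
How these limits might be enforced during a fetch — a sketch only; `fetch_ics_limited` is a hypothetical helper built on the `requests` streaming API and the constants above:

```python
import requests


def fetch_ics_limited(session: requests.Session, url: str) -> str:
    """Fetch an ICS feed while enforcing max_file_size (sketch)."""
    max_size = RATE_LIMITS['max_file_size']
    with session.get(url, timeout=30, stream=True) as response:
        response.raise_for_status()
        chunks, total = [], 0
        for chunk in response.iter_content(chunk_size=65536):
            total += len(chunk)
            if total > max_size:
                raise ValueError(f"ICS file exceeds {max_size} bytes")
            chunks.append(chunk)
        return b''.join(chunks).decode(response.encoding or 'utf-8')


# min_fetch_interval acts as a floor on the per-room setting:
# interval = max(room.ics_fetch_interval, RATE_LIMITS['min_fetch_interval'])
```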

### 3. ICS URL Encryption

```python
from cryptography.fernet import Fernet


class URLEncryption:
    def __init__(self):
        self.cipher = Fernet(settings.ENCRYPTION_KEY)

    def encrypt_url(self, url: str) -> str:
        """Encrypt ICS URL to protect embedded tokens."""
        return self.cipher.encrypt(url.encode()).decode()

    def decrypt_url(self, encrypted: str) -> str:
        """Decrypt ICS URL for fetching."""
        return self.cipher.decrypt(encrypted.encode()).decode()

    def mask_url(self, url: str) -> str:
        """Mask sensitive parts of a URL for display."""
        from urllib.parse import urlparse, urlunparse

        parsed = urlparse(url)
        # Keep scheme, host, and path structure but mask tokens
        if '/private-' in parsed.path:
            # Google Calendar format
            parts = parsed.path.split('/private-')
            masked_path = parts[0] + '/private-***/' + parts[1].split('/')[-1]
        elif 'token=' in url:
            # Query parameter token
            masked_path = parsed.path
            parsed = parsed._replace(query='token=***')
        else:
            # Generic masking of path segments that look like tokens,
            # including tokens at the end of the path
            import re
            masked_path = re.sub(r'/[a-zA-Z0-9]{20,}(/|$)', r'/***\1', parsed.path)

        return urlunparse(parsed._replace(path=masked_path))
```
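
The intended data flow, sketched with hypothetical `room` and `submitted_url` variables: encrypt on write, decrypt only at fetch time, expose only the masked form.

```python
encryption = URLEncryption()

# On PATCH /v1/rooms/{room_id}: persist only the encrypted form
room.ics_url = encryption.encrypt_url(submitted_url)

# In the sync task: decrypt just before fetching
ics_content = ICSFetchService().fetch_ics(encryption.decrypt_url(room.ics_url))

# In API responses: return only the masked URL
payload = {"ics_url": encryption.mask_url(submitted_url)}
```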

## Testing Strategy

### 1. Unit Tests

```python
# tests/test_ics_sync.py

def test_ics_parsing():
    """Test ICS file parsing."""
    ics_content = """BEGIN:VCALENDAR
VERSION:2.0
BEGIN:VEVENT
UID:test-123
SUMMARY:Team Meeting
LOCATION:https://reflector.monadical.com/engineering
DTSTART:20240115T100000Z
DTEND:20240115T110000Z
END:VEVENT
END:VCALENDAR"""

    service = ICSFetchService()
    calendar = service.parse_ics(ics_content)
    events = service.extract_room_events(
        calendar,
        "https://reflector.monadical.com/engineering"
    )

    assert len(events) == 1
    assert events[0]['title'] == 'Team Meeting'
```

### 2. Integration Tests

```python
from unittest.mock import patch


def test_full_sync_flow():
    """Test the complete sync workflow."""
    # Create room with ICS URL (encrypt URL to protect tokens)
    encryption = URLEncryption()
    room = Room(
        name="test-room",
        ics_url=encryption.encrypt_url("https://example.com/calendar.ics?token=secret"),
        ics_enabled=True
    )

    # Mock the ICS fetch (the service fetches through a requests.Session)
    with patch('requests.Session.get') as mock_get:
        mock_get.return_value.text = sample_ics_content

        # Run sync
        sync_room_calendar(room.id)

    # Verify events created
    events = CalendarEvent.query.filter_by(room_id=room.id).all()
    assert len(events) > 0
```

## Common ICS Provider Configurations

### Google Calendar
- URL Format: `https://calendar.google.com/calendar/ical/{calendar_id}/private-{token}/basic.ics`
- Authentication via token embedded in URL
- Updates every 3-8 hours by default

### Outlook/Office 365
- URL Format: `https://outlook.office365.com/owa/calendar/{id}/calendar.ics`
- May include token in URL path or query parameters
- Real-time updates

### Apple iCloud
- URL Format: `webcal://p{XX}-caldav.icloud.com/published/2/{token}`
- Convert webcal:// to https://
- Token embedded in URL path
- Public calendars only

### Nextcloud/ownCloud
- URL Format: `https://cloud.example.com/remote.php/dav/public-calendars/{token}`
- Token embedded in URL path
- Configurable update frequency

## Migration from CalDAV

If migrating from an existing CalDAV implementation:

1. **Database Migration**: Rename fields from `caldav_*` to `ics_*`
2. **URL Conversion**: Most CalDAV servers provide ICS export endpoints
3. **Authentication**: Convert from username/password to URL-embedded tokens
4. **Remove Dependencies**: Uninstall caldav library, add icalendar
5. **Update Background Tasks**: Replace CalDAV sync with ICS fetch

## Performance Optimizations

1. **Caching**: Use ETag/Last-Modified headers to avoid refetching unchanged calendars
2. **Incremental Sync**: Store last sync timestamp, only process new/modified events
3. **Batch Processing**: Process multiple room calendars in parallel
4. **Connection Pooling**: Reuse HTTP connections for multiple requests
5. **Compression**: Support gzip encoding for large ICS files

## Monitoring and Debugging

### Metrics to Track
- Sync success/failure rate per room
- Average sync duration
- ICS file sizes
- Number of events processed
- Failed event matches

### Debug Logging
```python
logger.debug(f"Fetching ICS from {room.ics_url}")
logger.debug(f"ICS content size: {len(ics_content)} bytes")
logger.debug(f"Found {len(events)} matching events")
logger.debug(f"Event UIDs: {[e['external_id'] for e in events]}")
```

### Common Issues
1. **SSL Certificate Errors**: Add certificate validation options
2. **Timeout Issues**: Increase timeout for large calendars
3. **Encoding Problems**: Handle various character encodings
4. **Timezone Mismatches**: Always convert to UTC
5. **Memory Issues**: Stream large ICS files instead of loading them entirely
337 PLAN.md Normal file

@@ -0,0 +1,337 @@

# ICS Calendar Integration Plan

## Core Concept

ICS calendar URLs are attached to rooms (not users) to enable automatic meeting tracking and management through periodic fetching of calendar data.

## Database Schema Updates

### 1. Add ICS configuration to rooms
- Add `ics_url` field to room table (URL to .ics file, may include auth token)
- Add `ics_fetch_interval` field to room table (default: 5 minutes, configurable)
- Add `ics_enabled` boolean field to room table
- Add `ics_last_sync` timestamp field to room table

### 2. Create calendar_events table
- `id` - UUID primary key
- `room_id` - Foreign key to room
- `external_id` - ICS event UID
- `title` - Event title
- `description` - Event description
- `start_time` - Event start timestamp
- `end_time` - Event end timestamp
- `attendees` - JSON field with attendee list and status
- `location` - Meeting location (should contain room name)
- `last_synced` - Last sync timestamp
- `is_deleted` - Boolean flag for soft delete (preserve past events)
- `ics_raw_data` - TEXT field to store raw VEVENT data for reference

### 3. Update meeting table
- Add `calendar_event_id` - Foreign key to calendar_events
- Add `calendar_metadata` - JSON field for additional calendar data
- Remove unique constraint on room_id + active status (allow multiple active meetings per room)

## Backend Implementation

### 1. ICS Sync Service
- Create a background task that runs based on the room's `ics_fetch_interval` (default: 5 minutes)
- For each room with ICS enabled, fetch the .ics file via HTTP/HTTPS
- Parse the ICS file using the icalendar library
- Extract VEVENT components and filter for events matching the room URL (e.g., "https://reflector.monadical.com/max")
- Store matching events in the calendar_events table
- Mark events as "upcoming" if start_time is within the next 30 minutes
- Pre-create Whereby meetings 1 minute before start (ensures no delay when users join)
- Soft-delete future events that were removed from the calendar (set is_deleted=true)
- Never delete past events (preserve for historical record)
- Support authenticated ICS feeds via tokens embedded in the URL

### 2. Meeting Management Updates
- Allow multiple active meetings per room
- Pre-create the meeting record 1 minute before the calendar event starts (ensures the meeting is ready)
- Link meetings to calendar_event for metadata
- Keep meetings active for 15 minutes after the last participant leaves (grace period)
- Don't auto-close if a new participant joins within the grace period

### 3. API Endpoints
- `GET /v1/rooms/{room_name}/meetings` - List all active and upcoming meetings for a room
  - Returns filtered data based on user role (owner vs participant)
- `GET /v1/rooms/{room_name}/meetings/upcoming` - List upcoming meetings (next 30 min)
  - Returns filtered data based on user role
- `POST /v1/rooms/{room_name}/meetings/{meeting_id}/join` - Join a specific meeting
- `PATCH /v1/rooms/{room_id}` - Update room settings (including ICS configuration)
  - ICS fields only visible/editable by the room owner
- `POST /v1/rooms/{room_name}/ics/sync` - Trigger manual ICS sync
  - Only accessible by the room owner
- `GET /v1/rooms/{room_name}/ics/status` - Get ICS sync status and last fetch time
  - Only accessible by the room owner

## Frontend Implementation

### 1. Room Settings Page
- Add ICS configuration section
- Field for ICS URL (e.g., Google Calendar private URL, Outlook ICS export)
- Field for fetch interval (dropdown: 1 min, 5 min, 10 min, 30 min, 1 hour)
- Test connection button (validates the ICS file can be fetched and parsed)
- Manual sync button
- Show last sync time and next scheduled sync

### 2. Meeting Selection Page (New)
- Show when accessing `/room/{room_name}`
- **Host view** (room owner):
  - Full calendar event details
  - Meeting title and description
  - Complete attendee list with RSVP status
  - Number of current participants
  - Duration (how long it's been running)
- **Participant view** (non-owners):
  - Meeting title only
  - Date and time
  - Number of current participants
  - Duration (how long it's been running)
  - No attendee list or description (privacy)
- Display upcoming meetings (visible 30 min before):
  - Show countdown to start
  - Can click to join early → redirected to a waiting page
  - Waiting page shows a countdown until the meeting starts
  - Meeting pre-created by the background task (ready when users arrive)
- Option to create an unscheduled meeting (uses the existing flow)

### 3. Meeting Room Updates
- Show calendar metadata in meeting info
- Display invited attendees vs actual participants
- Show the meeting title from the calendar event

## Meeting Lifecycle

### 1. Meeting Creation
- Automatic: Pre-created 1 minute before the calendar event starts (ensures the Whereby room is ready)
- Manual: User creates an unscheduled meeting (existing `/rooms/{room_name}/meeting` endpoint)
- A background task handles pre-creation to avoid delays when users join

### 2. Meeting Join Rules
- Can join active meetings immediately
- Can see upcoming meetings 30 minutes before start
- Can click to join upcoming meetings early → sent to a waiting page
- The waiting page automatically transitions to the meeting at the scheduled time
- Unscheduled meetings are always joinable (current behavior)

### 3. Meeting Closure Rules
- All meetings: 15-minute grace period after the last participant leaves (a closure-check sketch follows this list)
- If a participant rejoins within the grace period, keep the meeting active
- Calendar meetings: Force close 30 minutes after the scheduled end time
- Unscheduled meetings: Keep active for 8 hours (current behavior)
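
A minimal sketch of these rules; the meeting fields used here (`last_participant_left_at`, `calendar_event`, `created_at`) are assumptions based on the schema notes above, not the actual model:

```python
from datetime import datetime, timedelta

GRACE_PERIOD = timedelta(minutes=15)
CALENDAR_OVERRUN = timedelta(minutes=30)
UNSCHEDULED_LIFETIME = timedelta(hours=8)


def should_close_meeting(meeting, now: datetime) -> bool:
    """Return True when a meeting should be closed (sketch of the rules above)."""
    # Grace period: close once the room has been empty for 15+ minutes;
    # a rejoin clears last_participant_left_at and keeps the meeting active
    if meeting.last_participant_left_at is not None:
        if now - meeting.last_participant_left_at >= GRACE_PERIOD:
            return True

    if meeting.calendar_event is not None:
        # Calendar meetings: force close 30 minutes after the scheduled end
        return now >= meeting.calendar_event.end_time + CALENDAR_OVERRUN

    # Unscheduled meetings: keep active for 8 hours
    return now >= meeting.created_at + UNSCHEDULED_LIFETIME
```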

## ICS Parsing Logic

### 1. Event Matching
- Parse the ICS file using the Python icalendar library
- Iterate through VEVENT components
- Check the LOCATION field for the full FQDN URL (e.g., "https://reflector.monadical.com/max")
- Check DESCRIPTION for the room URL or a mention of it
- Support multiple formats:
  - Full URL: "https://reflector.monadical.com/max"
  - With /room path: "https://reflector.monadical.com/room/max"
  - Partial paths: "room/max", "/max room"

### 2. Attendee Extraction
- Parse ATTENDEE properties from VEVENT (see the sketch after this list)
- Extract email (MAILTO), name (CN parameter), and RSVP status (PARTSTAT)
- Store as JSON in calendar_events.attendees
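
A minimal sketch of that extraction, assuming icalendar's `vCalAddress` objects (a single attendee comes back unwrapped, so the sketch normalizes to a list):

```python
def extract_attendees(event) -> list[dict]:
    """Parse ATTENDEE/ORGANIZER into the JSON shape stored on calendar_events (sketch)."""
    attendees = event.get('ATTENDEE', [])
    if not isinstance(attendees, list):  # single attendee is not wrapped in a list
        attendees = [attendees]

    parsed = []
    for attendee in attendees:
        parsed.append({
            'email': str(attendee).replace('MAILTO:', '').replace('mailto:', ''),
            'name': str(attendee.params.get('CN', '')),
            'status': str(attendee.params.get('PARTSTAT', 'NEEDS-ACTION')),
        })

    organizer = event.get('ORGANIZER')
    if organizer is not None:
        parsed.append({
            'email': str(organizer).replace('MAILTO:', '').replace('mailto:', ''),
            'name': str(organizer.params.get('CN', '')),
            'status': 'ORGANIZER',
        })
    return parsed
```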

### 3. Sync Strategy
- Fetch the complete ICS file (contains all events)
- Filter events from (now - 1 hour) to (now + 24 hours) for processing
- Update existing events if LAST-MODIFIED or SEQUENCE changed
- Delete future events that no longer exist in the ICS (start_time > now)
- Keep past events for the historical record (never delete if start_time < now)
- Handle recurring events (RRULE) - expand to individual instances
- Track deleted calendar events to clean up future meetings
- Cache the ICS file hash to detect changes and skip unnecessary processing

## Security Considerations

### 1. ICS URL Security
- ICS URLs may contain authentication tokens (e.g., Google Calendar private URLs)
- Store full ICS URLs encrypted using Fernet to protect embedded tokens
- Validate ICS URLs (must be HTTPS in production)
- Never expose full ICS URLs in API responses (return a masked version)
- Rate limit ICS fetching to prevent abuse

### 2. Room Access
- Only the room owner can configure the ICS URL
- The ICS URL is shown to the room owner in masked form (hides embedded tokens)
- ICS settings are not visible to other users
- The meeting list is visible to all room participants
- ICS fetch logs are only visible to the room owner

### 3. Meeting Privacy
- Full calendar details are visible only to the room owner
- Participants see limited info: title and date/time only
- Attendee list and description are hidden from non-owners
- Meeting titles are visible to all in the room listing

## Implementation Phases

### Phase 1: Database and ICS Setup (Week 1) ✅ COMPLETED (2025-08-18)
1. ✅ Created database migrations for ICS fields and the calendar_events table
   - Added ics_url, ics_fetch_interval, ics_enabled, ics_last_sync, ics_last_etag to the room table
   - Created the calendar_event table with ics_uid (instead of external_id) and proper typing
   - Added calendar_event_id and calendar_metadata (JSONB) to the meeting table
   - Removed server_default from datetime fields for consistency
2. ✅ Installed the icalendar Python library for ICS parsing
   - Added icalendar>=6.0.0 to dependencies
   - No encryption needed - ICS URLs are read-only
3. ✅ Built the ICS fetch and sync service
   - Simple HTTP fetching without unnecessary validation
   - Proper TypedDict typing for event data structures
   - Supports any standard ICS format
   - Event matching on the full room URL only
4. ✅ API endpoints for ICS configuration
   - Room model updated to support ICS fields via the existing PATCH endpoint
   - POST /v1/rooms/{room_name}/ics/sync - Trigger manual sync (owner only)
   - GET /v1/rooms/{room_name}/ics/status - Get sync status (owner only)
   - GET /v1/rooms/{room_name}/meetings - List meetings with privacy controls
   - GET /v1/rooms/{room_name}/meetings/upcoming - List upcoming meetings
5. ✅ Celery background tasks for periodic sync
   - sync_room_ics - Sync an individual room calendar
   - sync_all_ics_calendars - Check all rooms and queue syncs based on fetch intervals
   - pre_create_upcoming_meetings - Pre-create Whereby meetings 1 minute before start
   - Tasks scheduled in the beat schedule (checks every minute, respects individual intervals)
6. ✅ Tests written and passing
   - 6 tests for Room ICS fields
   - 7 tests for the CalendarEvent model
   - 7 tests for the ICS sync service
   - 11 tests for API endpoints
   - 6 tests for background tasks
   - All 31 ICS-related tests passing

### Phase 2: Meeting Management (Week 2) ✅ COMPLETED (2025-08-19)
1. ✅ Updated meeting lifecycle logic with grace period support
   - 15-minute grace period after the last participant leaves
   - Automatic reactivation when participants rejoin
   - Force close calendar meetings 30 minutes after the scheduled end
2. ✅ Support multiple active meetings per room
   - Removed the unique constraint on active meetings
   - Added get_all_active_for_room() method
   - Added get_active_by_calendar_event() method
3. ✅ Implemented grace period logic
   - Added last_participant_left_at and grace_period_minutes fields
   - The process-meetings task handles grace period checking
   - Whereby webhooks clear the grace period on participant join
4. ✅ Link meetings to calendar events
   - Pre-created meetings properly linked via calendar_event_id
   - Calendar metadata stored with the meeting
   - API endpoints for listing and joining specific meetings

### Phase 3: Frontend Meeting Selection (Week 3)
1. Build the meeting selection page
2. Show active and upcoming meetings
3. Implement the waiting page for early joiners
4. Add automatic transition from waiting to meeting
5. Support unscheduled meeting creation

### Phase 4: Calendar Integration UI (Week 4)
1. Add ICS settings to room configuration
2. Display calendar metadata in meetings
3. Show attendee information
4. Add sync status indicators
5. Show fetch interval and next sync time

## Success Metrics
- Zero merged meetings from consecutive calendar events
- Successful ICS sync from major providers (Google Calendar, Outlook, Apple Calendar, Nextcloud)
- Meeting join accuracy: correct meeting 100% of the time
- Grace period prevents 90% of accidental meeting closures
- Configurable fetch intervals reduce unnecessary API calls

## Design Decisions
1. **ICS attached to room, not user** - Prevents duplicate meetings from multiple calendars
2. **Multiple active meetings per room** - Supported with a meeting selection page
3. **Grace period for rejoining** - 15 minutes after the last participant leaves
4. **Upcoming meeting visibility** - Show 30 minutes before, join only on time
5. **Calendar data storage** - Attached to the meeting record for full context
6. **No "ad-hoc" meetings** - Use the existing meeting creation flow (unscheduled meetings)
7. **ICS configuration via room PATCH** - Reuse the existing room configuration endpoint
8. **Event deletion handling** - Soft-delete future events, preserve past meetings
9. **Configurable fetch interval** - Balance between freshness and server load
10. **ICS over CalDAV** - Simpler implementation, wider compatibility, no complex auth

## Phase 2 Implementation Files

### Database Migrations
- `/server/migrations/versions/6025e9b2bef2_remove_one_active_meeting_per_room_.py` - Remove the unique constraint
- `/server/migrations/versions/d4a1c446458c_add_grace_period_fields_to_meeting.py` - Add grace period fields

### Updated Models
- `/server/reflector/db/meetings.py` - Added grace period fields and new query methods

### Updated Services
- `/server/reflector/worker/process.py` - Enhanced with grace period logic and multiple-meeting support

### Updated API
- `/server/reflector/views/rooms.py` - Added endpoints for listing active meetings and joining specific meetings
- `/server/reflector/views/whereby.py` - Clear the grace period on participant join

### Tests
- `/server/tests/test_multiple_active_meetings.py` - Comprehensive tests for Phase 2 features (5 tests)

## Phase 1 Implementation Files Created

### Database Models
- `/server/reflector/db/rooms.py` - Updated with ICS fields (url, fetch_interval, enabled, last_sync, etag)
- `/server/reflector/db/calendar_events.py` - New CalendarEvent model with ics_uid and proper typing
- `/server/reflector/db/meetings.py` - Updated with calendar_event_id and calendar_metadata (JSONB)

### Services
- `/server/reflector/services/ics_sync.py` - ICS fetching and parsing with TypedDict for proper typing

### API Endpoints
- `/server/reflector/views/rooms.py` - Added ICS management endpoints with privacy controls

### Background Tasks
- `/server/reflector/worker/ics_sync.py` - Celery tasks for automatic periodic sync
- `/server/reflector/worker/app.py` - Updated beat schedule for ICS tasks

### Tests
- `/server/tests/test_room_ics.py` - Room model ICS field tests (6 tests)
- `/server/tests/test_calendar_event.py` - CalendarEvent model tests (7 tests)
- `/server/tests/test_ics_sync.py` - ICS sync service tests (7 tests)
- `/server/tests/test_room_ics_api.py` - API endpoint tests (11 tests)
- `/server/tests/test_ics_background_tasks.py` - Background task tests (6 tests)

### Key Design Decisions
- No encryption needed - ICS URLs grant read-only access
- Using ics_uid instead of external_id for clarity
- Proper TypedDict typing for event data structures
- Removed unnecessary URL validation and webcal handling
- calendar_metadata on meetings stores flexible calendar data (organizer, recurrence, etc.)
- Background tasks query all rooms directly to avoid filtering issues
- Sync intervals respect per-room configuration

## Implementation Approach

### ICS Fetching vs CalDAV
- **ICS Benefits**:
  - Simpler implementation (HTTP GET vs the CalDAV protocol)
  - Wider compatibility (all calendar apps can export ICS)
  - No authentication complexity (a simple URL with an optional token)
  - Easier debugging (ICS is plain text)
  - Lower server requirements (no CalDAV library dependencies)

### Supported Calendar Providers
1. **Google Calendar**: Private ICS URL from calendar settings
2. **Outlook/Office 365**: ICS export URL from calendar sharing
3. **Apple Calendar**: Published calendar ICS URL
4. **Nextcloud**: Public/private calendar ICS export
5. **Any CalDAV server**: Via its ICS export endpoint

### ICS URL Examples
- Google: `https://calendar.google.com/calendar/ical/{calendar_id}/private-{token}/basic.ics`
- Outlook: `https://outlook.live.com/owa/calendar/{id}/calendar.ics`
- Custom: `https://example.com/calendars/room-schedule.ics`

### Fetch Interval Configuration
- 1 minute: For critical/high-activity rooms
- 5 minutes (default): Balance of freshness and efficiency
- 10 minutes: Standard meeting rooms
- 30 minutes: Low-activity rooms
- 1 hour: Rarely-used rooms or stable schedules
3 server/.gitignore vendored

@@ -176,8 +176,7 @@ artefacts/
audio_*.wav

# ignore local database
*.sqlite3
*.db
reflector.sqlite3
data/

dump.rdb
@@ -1,8 +1,7 @@
FROM python:3.12-slim

ENV PYTHONUNBUFFERED=1 \
    UV_LINK_MODE=copy \
    UV_NO_CACHE=1
    UV_LINK_MODE=copy

# builder install base dependencies
WORKDIR /tmp
@@ -14,8 +13,8 @@ ENV PATH="/root/.local/bin/:$PATH"
# install application dependencies
RUN mkdir -p /app
WORKDIR /app
COPY pyproject.toml uv.lock README.md /app/
RUN uv sync --compile-bytecode --locked
COPY pyproject.toml uv.lock /app/
RUN touch README.md && env uv sync --compile-bytecode --locked

# pre-download nltk packages
RUN uv run python -c "import nltk; nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"
@@ -27,15 +26,4 @@ COPY migrations /app/migrations
COPY reflector /app/reflector
WORKDIR /app

# Create symlink for libgomp if it doesn't exist (for ARM64 compatibility)
RUN if [ "$(uname -m)" = "aarch64" ] && [ ! -f /usr/lib/libgomp.so.1 ]; then \
        LIBGOMP_PATH=$(find /app/.venv/lib -path "*/torch.libs/libgomp*.so.*" 2>/dev/null | head -n1); \
        if [ -n "$LIBGOMP_PATH" ]; then \
            ln -sf "$LIBGOMP_PATH" /usr/lib/libgomp.so.1; \
        fi \
    fi

# Pre-check just to make sure the image will not fail
RUN uv run python -c "import silero_vad.model"

CMD ["./runserver.sh"]
@@ -40,5 +40,3 @@ uv run python -c "from reflector.pipelines.main_live_pipeline import task_pipeli
```bash
uv run python -c "from reflector.pipelines.main_live_pipeline import pipeline_post; pipeline_post(transcript_id='TRANSCRIPT_ID')"
```
@@ -4,8 +4,7 @@ This repository hold an API for the GPU implementation of the Reflector API serv
and use [Modal.com](https://modal.com)

- `reflector_diarizer.py` - Diarization API
- `reflector_transcriber.py` - Transcription API (Whisper)
- `reflector_transcriber_parakeet.py` - Transcription API (NVIDIA Parakeet)
- `reflector_transcriber.py` - Transcription API
- `reflector_translator.py` - Translation API

## Modal.com deployment
@@ -20,10 +19,6 @@ $ modal deploy reflector_transcriber.py
...
└── 🔨 Created web => https://xxxx--reflector-transcriber-web.modal.run

$ modal deploy reflector_transcriber_parakeet.py
...
└── 🔨 Created web => https://xxxx--reflector-transcriber-parakeet-web.modal.run

$ modal deploy reflector_llm.py
...
└── 🔨 Created web => https://xxxx--reflector-llm-web.modal.run
@@ -73,86 +68,6 @@ Authorization: bearer <REFLECTOR_APIKEY>

### Transcription

#### Parakeet Transcriber (`reflector_transcriber_parakeet.py`)

NVIDIA Parakeet is a state-of-the-art ASR model optimized for real-time transcription with superior word-level timestamps.

**GPU Configuration:**
- **A10G GPU** - Used for `/v1/audio/transcriptions` endpoint (small files, live transcription)
  - Higher concurrency (max_inputs=10)
  - Optimized for multiple small audio files
  - Supports batch processing for efficiency

- **L40S GPU** - Used for `/v1/audio/transcriptions-from-url` endpoint (large files)
  - Lower concurrency but more powerful processing
  - Optimized for single large audio files
  - VAD-based chunking for long-form audio

##### `/v1/audio/transcriptions` - Small file transcription

**request** (multipart/form-data)
- `file` or `files[]` - audio file(s) to transcribe
- `model` - model name (default: `nvidia/parakeet-tdt-0.6b-v2`)
- `language` - language code (default: `en`)
- `batch` - whether to use batch processing for multiple files (default: `true`)

**response**
```json
{
  "text": "transcribed text",
  "words": [
    {"word": "hello", "start": 0.0, "end": 0.5},
    {"word": "world", "start": 0.5, "end": 1.0}
  ],
  "filename": "audio.mp3"
}
```

For multiple files with batch=true:
```json
{
  "results": [
    {
      "filename": "audio1.mp3",
      "text": "transcribed text",
      "words": [...]
    },
    {
      "filename": "audio2.mp3",
      "text": "transcribed text",
      "words": [...]
    }
  ]
}
```

##### `/v1/audio/transcriptions-from-url` - Large file transcription

**request** (application/json)
```json
{
  "audio_file_url": "https://example.com/audio.mp3",
  "model": "nvidia/parakeet-tdt-0.6b-v2",
  "language": "en",
  "timestamp_offset": 0.0
}
```

**response**
```json
{
  "text": "transcribed text from large file",
  "words": [
    {"word": "hello", "start": 0.0, "end": 0.5},
    {"word": "world", "start": 0.5, "end": 1.0}
  ]
}
```

**Supported file types:** mp3, mp4, mpeg, mpga, m4a, wav, webm

#### Whisper Transcriber (`reflector_transcriber.py`)

`POST /transcribe`

**request** (multipart/form-data)
@@ -4,80 +4,14 @@ Reflector GPU backend - diarizer
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from typing import Mapping, NewType
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import modal
|
||||
import modal.gpu
|
||||
from modal import App, Image, Secret, asgi_app, enter, method
|
||||
from pydantic import BaseModel
|
||||
|
||||
PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.1"
|
||||
MODEL_DIR = "/root/diarization_models"
|
||||
UPLOADS_PATH = "/uploads"
|
||||
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
|
||||
|
||||
DiarizerUniqFilename = NewType("DiarizerUniqFilename", str)
|
||||
AudioFileExtension = NewType("AudioFileExtension", str)
|
||||
|
||||
app = modal.App(name="reflector-diarizer")
|
||||
|
||||
# Volume for temporary file uploads
|
||||
upload_volume = modal.Volume.from_name("diarizer-uploads", create_if_missing=True)
|
||||
|
||||
|
||||
def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
|
||||
parsed_url = urlparse(url)
|
||||
url_path = parsed_url.path
|
||||
|
||||
for ext in SUPPORTED_FILE_EXTENSIONS:
|
||||
if url_path.lower().endswith(f".{ext}"):
|
||||
return AudioFileExtension(ext)
|
||||
|
||||
content_type = headers.get("content-type", "").lower()
|
||||
if "audio/mpeg" in content_type or "audio/mp3" in content_type:
|
||||
return AudioFileExtension("mp3")
|
||||
if "audio/wav" in content_type:
|
||||
return AudioFileExtension("wav")
|
||||
if "audio/mp4" in content_type:
|
||||
return AudioFileExtension("mp4")
|
||||
|
||||
raise ValueError(
|
||||
f"Unsupported audio format for URL: {url}. "
|
||||
f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
|
||||
)
|
||||
|
||||
|
||||
def download_audio_to_volume(
|
||||
audio_file_url: str,
|
||||
) -> tuple[DiarizerUniqFilename, AudioFileExtension]:
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
print(f"Checking audio file at: {audio_file_url}")
|
||||
response = requests.head(audio_file_url, allow_redirects=True)
|
||||
if response.status_code == 404:
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
print(f"Downloading audio file from: {audio_file_url}")
|
||||
response = requests.get(audio_file_url, allow_redirects=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
print(f"Download failed with status {response.status_code}: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=f"Failed to download audio file: {response.status_code}",
|
||||
)
|
||||
|
||||
audio_suffix = detect_audio_format(audio_file_url, response.headers)
|
||||
unique_filename = DiarizerUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
|
||||
file_path = f"{UPLOADS_PATH}/{unique_filename}"
|
||||
|
||||
print(f"Writing file to: {file_path} (size: {len(response.content)} bytes)")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
upload_volume.commit()
|
||||
print(f"File saved as: {unique_filename}")
|
||||
return unique_filename, audio_suffix
|
||||
app = App(name="reflector-diarizer")
|
||||
|
||||
|
||||
def migrate_cache_llm():
|
||||
@@ -105,7 +39,7 @@ def download_pyannote_audio():
|
||||
|
||||
|
||||
diarizer_image = (
|
||||
modal.Image.debian_slim(python_version="3.10.8")
|
||||
Image.debian_slim(python_version="3.10.8")
|
||||
.pip_install(
|
||||
"pyannote.audio==3.1.0",
|
||||
"requests",
|
||||
@@ -121,8 +55,7 @@ diarizer_image = (
|
||||
"hf-transfer",
|
||||
)
|
||||
.run_function(
|
||||
download_pyannote_audio,
|
||||
secrets=[modal.Secret.from_name("hf_token")],
|
||||
download_pyannote_audio, secrets=[Secret.from_name("my-huggingface-secret")]
|
||||
)
|
||||
.run_function(migrate_cache_llm)
|
||||
.env(
|
||||
@@ -137,60 +70,53 @@ diarizer_image = (
|
||||
|
||||
|
||||
@app.cls(
|
||||
gpu="A100",
|
||||
gpu=modal.gpu.A100(size="40GB"),
|
||||
timeout=60 * 30,
|
||||
scaledown_window=60,
|
||||
allow_concurrent_inputs=1,
|
||||
image=diarizer_image,
|
||||
volumes={UPLOADS_PATH: upload_volume},
|
||||
enable_memory_snapshot=True,
|
||||
experimental_options={"enable_gpu_snapshot": True},
|
||||
secrets=[
|
||||
modal.Secret.from_name("hf_token"),
|
||||
],
|
||||
)
|
||||
@modal.concurrent(max_inputs=1)
|
||||
class Diarizer:
|
||||
@modal.enter(snap=True)
|
||||
@enter()
|
||||
def enter(self):
|
||||
import torch
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = "cuda" if self.use_gpu else "cpu"
|
||||
print(f"Using device: {self.device}")
|
||||
self.diarization_pipeline = Pipeline.from_pretrained(
|
||||
PYANNOTE_MODEL_NAME,
|
||||
cache_dir=MODEL_DIR,
|
||||
use_auth_token=os.environ["HF_TOKEN"],
|
||||
PYANNOTE_MODEL_NAME, cache_dir=MODEL_DIR
|
||||
)
|
||||
self.diarization_pipeline.to(torch.device(self.device))
|
||||
|
||||
@modal.method()
|
||||
def diarize(self, filename: str, timestamp: float = 0.0):
|
||||
@method()
|
||||
def diarize(self, audio_data: str, audio_suffix: str, timestamp: float):
|
||||
import tempfile
|
||||
|
||||
import torchaudio
|
||||
|
||||
upload_volume.reload()
|
||||
with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
|
||||
fp.write(audio_data)
|
||||
|
||||
file_path = f"{UPLOADS_PATH}/{filename}"
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
print(f"Diarizing audio from: {file_path}")
|
||||
waveform, sample_rate = torchaudio.load(file_path)
|
||||
diarization = self.diarization_pipeline(
|
||||
{"waveform": waveform, "sample_rate": sample_rate}
|
||||
)
|
||||
|
||||
words = []
|
||||
for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
|
||||
words.append(
|
||||
{
|
||||
"start": round(timestamp + diarization_segment.start, 3),
|
||||
"end": round(timestamp + diarization_segment.end, 3),
|
||||
"speaker": int(speaker[-2:]),
|
||||
}
|
||||
print("Diarizing audio")
|
||||
waveform, sample_rate = torchaudio.load(fp.name)
|
||||
diarization = self.diarization_pipeline(
|
||||
{"waveform": waveform, "sample_rate": sample_rate}
|
||||
)
|
||||
print("Diarization complete")
|
||||
return {"diarization": words}
|
||||
|
||||
words = []
|
||||
for diarization_segment, _, speaker in diarization.itertracks(
|
||||
yield_label=True
|
||||
):
|
||||
words.append(
|
||||
{
|
||||
"start": round(timestamp + diarization_segment.start, 3),
|
||||
"end": round(timestamp + diarization_segment.end, 3),
|
||||
"speaker": int(speaker[-2:]),
|
||||
}
|
||||
)
|
||||
print("Diarization complete")
|
||||
return {"diarization": words}
# -------------------------------------------------------------------
@@ -201,18 +127,17 @@ class Diarizer:
@app.function(
    timeout=60 * 10,
    scaledown_window=60 * 3,
    allow_concurrent_inputs=40,
    secrets=[
        modal.Secret.from_name("reflector-gpu"),
        Secret.from_name("reflector-gpu"),
    ],
    volumes={UPLOADS_PATH: upload_volume},
    image=diarizer_image,
)
@modal.concurrent(max_inputs=40)
@modal.asgi_app()
@asgi_app()
def web():
    import requests
    from fastapi import Depends, FastAPI, HTTPException, status
    from fastapi.security import OAuth2PasswordBearer
    from pydantic import BaseModel

    diarizerstub = Diarizer()

@@ -228,26 +153,35 @@ def web():
            headers={"WWW-Authenticate": "Bearer"},
        )

    def validate_audio_file(audio_file_url: str):
        # Check if the audio file exists
        response = requests.head(audio_file_url, allow_redirects=True)
        if response.status_code == 404:
            raise HTTPException(
                status_code=response.status_code,
                detail="The audio file does not exist.",
            )

    class DiarizationResponse(BaseModel):
        result: dict

    @app.post("/diarize", dependencies=[Depends(apikey_auth)])
    def diarize(audio_file_url: str, timestamp: float = 0.0) -> DiarizationResponse:
        unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)
    @app.post(
        "/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)]
    )
    def diarize(
        audio_file_url: str, timestamp: float = 0.0
    ) -> HTTPException | DiarizationResponse:
        # Currently the uploaded files are in mp3 format
        audio_suffix = "mp3"

        try:
            func = diarizerstub.diarize.spawn(
                filename=unique_filename, timestamp=timestamp
            )
            result = func.get()
            return result
        finally:
            try:
                file_path = f"{UPLOADS_PATH}/{unique_filename}"
                print(f"Deleting file: {file_path}")
                os.remove(file_path)
                upload_volume.commit()
            except Exception as e:
                print(f"Error cleaning up {unique_filename}: {e}")
        print("Downloading audio file")
        response = requests.get(audio_file_url, allow_redirects=True)
        print("Audio file downloaded successfully")

        func = diarizerstub.diarize.spawn(
            audio_data=response.content, audio_suffix=audio_suffix, timestamp=timestamp
        )
        result = func.get()
        return result

    return app
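A minimal client sketch for the endpoint above; the base URL and API key are placeholders, not values from this repository:

    import requests

    BASE_URL = "https://example.modal.run"  # hypothetical deployment URL
    API_KEY = "changeme"                    # hypothetical REFLECTOR_GPU_APIKEY value

    resp = requests.post(
        f"{BASE_URL}/diarize",
        params={"audio_file_url": "https://example.com/audio.mp3", "timestamp": 0.0},
        headers={"Authorization": f"Bearer {API_KEY}"},
    )
    resp.raise_for_status()
    print(resp.json()["diarization"][:3])  # first few speaker segments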
@@ -1,622 +0,0 @@
import logging
import os
import sys
import threading
import uuid
from typing import Mapping, NewType
from urllib.parse import urlparse

import modal

MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
SUPPORTED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
SAMPLERATE = 16000
UPLOADS_PATH = "/uploads"
CACHE_PATH = "/cache"
VAD_CONFIG = {
    "max_segment_duration": 30.0,
    "batch_max_files": 10,
    "batch_max_duration": 5.0,
    "min_segment_duration": 0.02,
    "silence_padding": 0.5,
    "window_size": 512,
}

ParakeetUniqFilename = NewType("ParakeetUniqFilename", str)
AudioFileExtension = NewType("AudioFileExtension", str)

app = modal.App("reflector-transcriber-parakeet")

# Volume for caching model weights
model_cache = modal.Volume.from_name("parakeet-model-cache", create_if_missing=True)
# Volume for temporary file uploads
upload_volume = modal.Volume.from_name("parakeet-uploads", create_if_missing=True)

image = (
    modal.Image.from_registry(
        "nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04", add_python="3.12"
    )
    .env(
        {
            "HF_HUB_ENABLE_HF_TRANSFER": "1",
            "HF_HOME": "/cache",
            "DEBIAN_FRONTEND": "noninteractive",
            "CXX": "g++",
            "CC": "g++",
        }
    )
    .apt_install("ffmpeg")
    .pip_install(
        "hf_transfer==0.1.9",
        "huggingface_hub[hf-xet]==0.31.2",
        "nemo_toolkit[asr]==2.3.0",
        "cuda-python==12.8.0",
        "fastapi==0.115.12",
        "numpy<2",
        "librosa==0.10.1",
        "requests",
        "silero-vad==5.1.0",
        "torch",
    )
    .entrypoint([])  # silence chatty logs by container on start
)


def detect_audio_format(url: str, headers: Mapping[str, str]) -> AudioFileExtension:
    parsed_url = urlparse(url)
    url_path = parsed_url.path

    for ext in SUPPORTED_FILE_EXTENSIONS:
        if url_path.lower().endswith(f".{ext}"):
            return AudioFileExtension(ext)

    content_type = headers.get("content-type", "").lower()
    if "audio/mpeg" in content_type or "audio/mp3" in content_type:
        return AudioFileExtension("mp3")
    if "audio/wav" in content_type:
        return AudioFileExtension("wav")
    if "audio/mp4" in content_type:
        return AudioFileExtension("mp4")

    raise ValueError(
        f"Unsupported audio format for URL: {url}. "
        f"Supported extensions: {', '.join(SUPPORTED_FILE_EXTENSIONS)}"
    )


def download_audio_to_volume(
    audio_file_url: str,
) -> tuple[ParakeetUniqFilename, AudioFileExtension]:
    import requests
    from fastapi import HTTPException

    response = requests.head(audio_file_url, allow_redirects=True)
    if response.status_code == 404:
        raise HTTPException(status_code=404, detail="Audio file not found")

    response = requests.get(audio_file_url, allow_redirects=True)
    response.raise_for_status()

    audio_suffix = detect_audio_format(audio_file_url, response.headers)
    unique_filename = ParakeetUniqFilename(f"{uuid.uuid4()}.{audio_suffix}")
    file_path = f"{UPLOADS_PATH}/{unique_filename}"

    with open(file_path, "wb") as f:
        f.write(response.content)

    upload_volume.commit()
    return unique_filename, audio_suffix


def pad_audio(audio_array, sample_rate: int = SAMPLERATE):
    """Add 0.5 seconds of silence if audio is less than 500ms.

    This is a workaround for a Parakeet bug where very short audio (<500ms) causes:
    ValueError: `char_offsets`: [] and `processed_tokens`: [157, 834, 834, 841]
    have to be of the same length

    See: https://github.com/NVIDIA/NeMo/issues/8451
    """
    import numpy as np

    audio_duration = len(audio_array) / sample_rate
    if audio_duration < 0.5:
        silence_samples = int(sample_rate * 0.5)
        silence = np.zeros(silence_samples, dtype=np.float32)
        return np.concatenate([audio_array, silence])
    return audio_array
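# Quick illustration (added commentary, not in the original file): a 0.3 s clip
# gets a full extra 0.5 s of silence appended, clearing the 500 ms threshold:
#     import numpy as np
#     clip = np.zeros(int(SAMPLERATE * 0.3), dtype=np.float32)
#     assert len(pad_audio(clip)) / SAMPLERATE == 0.8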
@app.cls(
    gpu="A10G",
    timeout=600,
    scaledown_window=300,
    image=image,
    volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
    enable_memory_snapshot=True,
    experimental_options={"enable_gpu_snapshot": True},
)
@modal.concurrent(max_inputs=10)
class TranscriberParakeetLive:
    @modal.enter(snap=True)
    def enter(self):
        import nemo.collections.asr as nemo_asr

        logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)

        self.lock = threading.Lock()
        self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
        device = next(self.model.parameters()).device
        print(f"Model is on device: {device}")

    @modal.method()
    def transcribe_segment(
        self,
        filename: str,
    ):
        import librosa

        upload_volume.reload()

        file_path = f"{UPLOADS_PATH}/{filename}"
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
        padded_audio = pad_audio(audio_array, sample_rate)

        with self.lock:
            with NoStdStreams():
                (output,) = self.model.transcribe([padded_audio], timestamps=True)

        text = output.text.strip()
        words = [
            {
                "word": word_info["word"],
                "start": round(word_info["start"], 2),
                "end": round(word_info["end"], 2),
            }
            for word_info in output.timestamp["word"]
        ]

        return {"text": text, "words": words}

    @modal.method()
    def transcribe_batch(
        self,
        filenames: list[str],
    ):
        import librosa

        upload_volume.reload()

        results = []
        audio_arrays = []

        # Load all audio files with padding
        for filename in filenames:
            file_path = f"{UPLOADS_PATH}/{filename}"
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Batch file not found: {file_path}")

            audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
            padded_audio = pad_audio(audio_array, sample_rate)
            audio_arrays.append(padded_audio)

        with self.lock:
            with NoStdStreams():
                outputs = self.model.transcribe(audio_arrays, timestamps=True)

        # Process results for each file
        for i, (filename, output) in enumerate(zip(filenames, outputs)):
            text = output.text.strip()

            words = [
                {
                    "word": word_info["word"],
                    "start": round(word_info["start"], 2),
                    "end": round(word_info["end"], 2),
                }
                for word_info in output.timestamp["word"]
            ]

            results.append(
                {
                    "filename": filename,
                    "text": text,
                    "words": words,
                }
            )

        return results


# L40S class for file transcription (bigger files)
@app.cls(
    gpu="L40S",
    timeout=900,
    image=image,
    volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
    enable_memory_snapshot=True,
    experimental_options={"enable_gpu_snapshot": True},
)
class TranscriberParakeetFile:
    @modal.enter(snap=True)
    def enter(self):
        import nemo.collections.asr as nemo_asr
        import torch
        from silero_vad import load_silero_vad

        logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)

        self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=MODEL_NAME)
        device = next(self.model.parameters()).device
        print(f"Model is on device: {device}")

        torch.set_num_threads(1)
        self.vad_model = load_silero_vad(onnx=False)
        print("Silero VAD initialized")

    @modal.method()
    def transcribe_segment(
        self,
        filename: str,
        timestamp_offset: float = 0.0,
    ):
        import librosa
        import numpy as np
        from silero_vad import VADIterator

        def load_and_convert_audio(file_path):
            audio_array, sample_rate = librosa.load(file_path, sr=SAMPLERATE, mono=True)
            return audio_array

        def vad_segment_generator(audio_array):
            """Generate speech segments using VAD with start/end sample indices"""
            vad_iterator = VADIterator(self.vad_model, sampling_rate=SAMPLERATE)
            window_size = VAD_CONFIG["window_size"]
            start = None

            for i in range(0, len(audio_array), window_size):
                chunk = audio_array[i : i + window_size]
                if len(chunk) < window_size:
                    chunk = np.pad(
                        chunk, (0, window_size - len(chunk)), mode="constant"
                    )

                speech_dict = vad_iterator(chunk)
                if not speech_dict:
                    continue

                if "start" in speech_dict:
                    start = speech_dict["start"]
                    continue

                if "end" in speech_dict and start is not None:
                    end = speech_dict["end"]
                    start_time = start / float(SAMPLERATE)
                    end_time = end / float(SAMPLERATE)

                    # Extract the actual audio segment
                    audio_segment = audio_array[start:end]

                    yield (start_time, end_time, audio_segment)
                    start = None

            vad_iterator.reset_states()

        def vad_segment_filter(segments):
            """Filter VAD segments by duration and chunk large segments"""
            min_dur = VAD_CONFIG["min_segment_duration"]
            max_dur = VAD_CONFIG["max_segment_duration"]

            for start_time, end_time, audio_segment in segments:
                segment_duration = end_time - start_time

                # Skip very small segments
                if segment_duration < min_dur:
                    continue

                # If segment is within max duration, yield as-is
                if segment_duration <= max_dur:
                    yield (start_time, end_time, audio_segment)
                    continue

                # Chunk large segments into smaller pieces
                chunk_samples = int(max_dur * SAMPLERATE)
                current_start = start_time

                for chunk_offset in range(0, len(audio_segment), chunk_samples):
                    chunk_audio = audio_segment[
                        chunk_offset : chunk_offset + chunk_samples
                    ]
                    if len(chunk_audio) == 0:
                        break

                    chunk_duration = len(chunk_audio) / float(SAMPLERATE)
                    chunk_end = current_start + chunk_duration

                    # Only yield chunks that meet minimum duration
                    if chunk_duration >= min_dur:
                        yield (current_start, chunk_end, chunk_audio)

                    current_start = chunk_end

        def batch_segments(segments, max_files=10, max_duration=5.0):
            batch = []
            batch_duration = 0.0

            for start_time, end_time, audio_segment in segments:
                segment_duration = end_time - start_time

                if segment_duration < VAD_CONFIG["silence_padding"]:
                    silence_samples = int(
                        (VAD_CONFIG["silence_padding"] - segment_duration) * SAMPLERATE
                    )
                    padding = np.zeros(silence_samples, dtype=np.float32)
                    audio_segment = np.concatenate([audio_segment, padding])
                    segment_duration = VAD_CONFIG["silence_padding"]

                batch.append((start_time, end_time, audio_segment))
                batch_duration += segment_duration

                if len(batch) >= max_files or batch_duration >= max_duration:
                    yield batch
                    batch = []
                    batch_duration = 0.0

            if batch:
                yield batch

        def transcribe_batch(model, audio_segments):
            with NoStdStreams():
                outputs = model.transcribe(audio_segments, timestamps=True)
            return outputs

        def emit_results(
            results,
            segments_info,
            batch_index,
            total_batches,
        ):
            """Yield transcribed text and word timings from model output, adjusting timestamps to absolute positions."""
            for i, (output, (start_time, end_time, _)) in enumerate(
                zip(results, segments_info)
            ):
                text = output.text.strip()
                words = [
                    {
                        "word": word_info["word"],
                        "start": round(
                            word_info["start"] + start_time + timestamp_offset, 2
                        ),
                        "end": round(
                            word_info["end"] + start_time + timestamp_offset, 2
                        ),
                    }
                    for word_info in output.timestamp["word"]
                ]

                yield text, words

        upload_volume.reload()

        file_path = f"{UPLOADS_PATH}/{filename}"
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        audio_array = load_and_convert_audio(file_path)
        total_duration = len(audio_array) / float(SAMPLERATE)
        processed_duration = 0.0

        all_text_parts = []
        all_words = []

        raw_segments = vad_segment_generator(audio_array)
        filtered_segments = vad_segment_filter(raw_segments)
        batches = batch_segments(
            filtered_segments,
            VAD_CONFIG["batch_max_files"],
            VAD_CONFIG["batch_max_duration"],
        )

        batch_index = 0
        total_batches = max(
            1, int(total_duration / VAD_CONFIG["batch_max_duration"]) + 1
        )

        for batch in batches:
            batch_index += 1
            audio_segments = [seg[2] for seg in batch]
            results = transcribe_batch(self.model, audio_segments)

            for text, words in emit_results(
                results,
                batch,
                batch_index,
                total_batches,
            ):
                if not text:
                    continue
                all_text_parts.append(text)
                all_words.extend(words)

            processed_duration += sum(len(seg[2]) / float(SAMPLERATE) for seg in batch)

        combined_text = " ".join(all_text_parts)
        return {"text": combined_text, "words": all_words}
@app.function(
    scaledown_window=60,
    timeout=600,
    secrets=[
        modal.Secret.from_name("reflector-gpu"),
    ],
    volumes={CACHE_PATH: model_cache, UPLOADS_PATH: upload_volume},
    image=image,
)
@modal.concurrent(max_inputs=40)
@modal.asgi_app()
def web():
    import os
    import uuid

    from fastapi import (
        Body,
        Depends,
        FastAPI,
        Form,
        HTTPException,
        UploadFile,
        status,
    )
    from fastapi.security import OAuth2PasswordBearer
    from pydantic import BaseModel

    transcriber_live = TranscriberParakeetLive()
    transcriber_file = TranscriberParakeetFile()

    app = FastAPI()

    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

    def apikey_auth(apikey: str = Depends(oauth2_scheme)):
        if apikey == os.environ["REFLECTOR_GPU_APIKEY"]:
            return
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "Bearer"},
        )

    class TranscriptResponse(BaseModel):
        result: dict

    @app.post("/v1/audio/transcriptions", dependencies=[Depends(apikey_auth)])
    def transcribe(
        file: UploadFile = None,
        files: list[UploadFile] | None = None,
        model: str = Form(MODEL_NAME),
        language: str = Form("en"),
        batch: bool = Form(False),
    ):
        # Parakeet only supports English
        if language != "en":
            raise HTTPException(
                status_code=400,
                detail=f"Parakeet model only supports English. Got language='{language}'",
            )
        # Handle both single file and multiple files
        if not file and not files:
            raise HTTPException(
                status_code=400, detail="Either 'file' or 'files' parameter is required"
            )
        if batch and not files:
            raise HTTPException(
                status_code=400, detail="Batch transcription requires 'files'"
            )

        upload_files = [file] if file else files

        # Upload files to volume
        uploaded_filenames = []
        for upload_file in upload_files:
            audio_suffix = upload_file.filename.split(".")[-1]
            assert audio_suffix in SUPPORTED_FILE_EXTENSIONS

            # Generate unique filename
            unique_filename = f"{uuid.uuid4()}.{audio_suffix}"
            file_path = f"{UPLOADS_PATH}/{unique_filename}"

            print(f"Writing file to: {file_path}")
            with open(file_path, "wb") as f:
                content = upload_file.file.read()
                f.write(content)

            uploaded_filenames.append(unique_filename)

        upload_volume.commit()

        try:
            # Use A10G live transcriber for per-file transcription
            if batch and len(upload_files) > 1:
                # Use batch transcription
                func = transcriber_live.transcribe_batch.spawn(
                    filenames=uploaded_filenames,
                )
                results = func.get()
                return {"results": results}

            # Per-file transcription
            results = []
            for filename in uploaded_filenames:
                func = transcriber_live.transcribe_segment.spawn(
                    filename=filename,
                )
                result = func.get()
                result["filename"] = filename
                results.append(result)

            return {"results": results} if len(results) > 1 else results[0]

        finally:
            for filename in uploaded_filenames:
                try:
                    file_path = f"{UPLOADS_PATH}/{filename}"
                    print(f"Deleting file: {file_path}")
                    os.remove(file_path)
                except Exception as e:
                    print(f"Error deleting {filename}: {e}")

            upload_volume.commit()

    @app.post("/v1/audio/transcriptions-from-url", dependencies=[Depends(apikey_auth)])
    def transcribe_from_url(
        audio_file_url: str = Body(
            ..., description="URL of the audio file to transcribe"
        ),
        model: str = Body(MODEL_NAME),
        language: str = Body("en", description="Language code (only 'en' supported)"),
        timestamp_offset: float = Body(0.0),
    ):
        # Parakeet only supports English
        if language != "en":
            raise HTTPException(
                status_code=400,
                detail=f"Parakeet model only supports English. Got language='{language}'",
            )
        unique_filename, audio_suffix = download_audio_to_volume(audio_file_url)

        try:
            func = transcriber_file.transcribe_segment.spawn(
                filename=unique_filename,
                timestamp_offset=timestamp_offset,
            )
            result = func.get()
            return result
        finally:
            try:
                file_path = f"{UPLOADS_PATH}/{unique_filename}"
                print(f"Deleting file: {file_path}")
                os.remove(file_path)
                upload_volume.commit()
            except Exception as e:
                print(f"Error cleaning up {unique_filename}: {e}")

    return app


class NoStdStreams:
    def __init__(self):
        self.devnull = open(os.devnull, "w")

    def __enter__(self):
        self._stdout, self._stderr = sys.stdout, sys.stderr
        self._stdout.flush()
        self._stderr.flush()
        sys.stdout, sys.stderr = self.devnull, self.devnull

    def __exit__(self, exc_type, exc_value, traceback):
        sys.stdout, sys.stderr = self._stdout, self._stderr
        self.devnull.close()
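Usage note (not part of the diff): anything printed inside the context manager goes to os.devnull, which is how the transcribers above silence NeMo's per-call output:

    with NoStdStreams():
        print("suppressed")
    print("visible again")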
@@ -1,64 +0,0 @@
"""add_long_summary_to_search_vector

Revision ID: 0ab2d7ffaa16
Revises: b1c33bd09963
Create Date: 2025-08-15 13:27:52.680211

"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "0ab2d7ffaa16"
down_revision: Union[str, None] = "b1c33bd09963"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Drop the existing search vector column and index
    op.drop_index("idx_transcript_search_vector_en", table_name="transcript")
    op.drop_column("transcript", "search_vector_en")

    # Recreate the search vector column with long_summary included
    op.execute("""
        ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
        GENERATED ALWAYS AS (
            setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
            setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') ||
            setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')
        ) STORED
    """)

    # Recreate the GIN index for the search vector
    op.create_index(
        "idx_transcript_search_vector_en",
        "transcript",
        ["search_vector_en"],
        postgresql_using="gin",
    )


def downgrade() -> None:
    # Drop the updated search vector column and index
    op.drop_index("idx_transcript_search_vector_en", table_name="transcript")
    op.drop_column("transcript", "search_vector_en")

    # Recreate the original search vector column without long_summary
    op.execute("""
        ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
        GENERATED ALWAYS AS (
            setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
            setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')
        ) STORED
    """)

    # Recreate the GIN index for the search vector
    op.create_index(
        "idx_transcript_search_vector_en",
        "transcript",
        ["search_vector_en"],
        postgresql_using="gin",
    )
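For reference, a rough sketch of how such a generated column is queried (PostgreSQL assumed; the A/B/C weights above rank title hits over summary and caption hits). The query mirrors the ts_rank/websearch_to_tsquery combination the search controller builds via SQLAlchemy later in this diff:

    import sqlalchemy as sa

    fts_query = sa.text(
        "SELECT id, ts_rank(search_vector_en, websearch_to_tsquery('english', :q), 32) AS rank "
        "FROM transcript "
        "WHERE search_vector_en @@ websearch_to_tsquery('english', :q) "
        "ORDER BY rank DESC"
    )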
@@ -0,0 +1,53 @@
"""remove_one_active_meeting_per_room_constraint

Revision ID: 6025e9b2bef2
Revises: 9f5c78d352d6
Create Date: 2025-08-18 18:45:44.418392

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "6025e9b2bef2"
down_revision: Union[str, None] = "9f5c78d352d6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Remove the unique constraint that prevents multiple active meetings per room
    # This is needed to support calendar integration with overlapping meetings
    # Check if index exists before trying to drop it
    from alembic import context

    if context.get_context().dialect.name == "postgresql":
        conn = op.get_bind()
        result = conn.execute(
            sa.text(
                "SELECT 1 FROM pg_indexes WHERE indexname = 'idx_one_active_meeting_per_room'"
            )
        )
        if result.fetchone():
            op.drop_index("idx_one_active_meeting_per_room", table_name="meeting")
    else:
        # For SQLite, just try to drop it
        try:
            op.drop_index("idx_one_active_meeting_per_room", table_name="meeting")
        except:
            pass


def downgrade() -> None:
    # Restore the unique constraint
    op.create_index(
        "idx_one_active_meeting_per_room",
        "meeting",
        ["room_id"],
        unique=True,
        postgresql_where=sa.text("is_active = true"),
        sqlite_where=sa.text("is_active = 1"),
    )
@@ -1,41 +0,0 @@
"""add_search_optimization_indexes

Revision ID: b1c33bd09963
Revises: 9f5c78d352d6
Create Date: 2025-08-14 17:26:02.117408

"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "b1c33bd09963"
down_revision: Union[str, None] = "9f5c78d352d6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Add indexes for actual search filtering patterns used in frontend
    # Based on /browse page filters: room_id and source_kind

    # Index for room_id + created_at (for room-specific searches with date ordering)
    op.create_index(
        "idx_transcript_room_id_created_at",
        "transcript",
        ["room_id", "created_at"],
        if_not_exists=True,
    )

    # Index for source_kind alone (actively used filter in frontend)
    op.create_index(
        "idx_transcript_source_kind", "transcript", ["source_kind"], if_not_exists=True
    )


def downgrade() -> None:
    # Remove the indexes in reverse order
    op.drop_index("idx_transcript_source_kind", "transcript", if_exists=True)
    op.drop_index("idx_transcript_room_id_created_at", "transcript", if_exists=True)
@@ -0,0 +1,34 @@
"""add_grace_period_fields_to_meeting

Revision ID: d4a1c446458c
Revises: 6025e9b2bef2
Create Date: 2025-08-18 18:50:37.768052

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "d4a1c446458c"
down_revision: Union[str, None] = "6025e9b2bef2"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Add fields to track when participants left for grace period logic
    op.add_column(
        "meeting", sa.Column("last_participant_left_at", sa.DateTime(timezone=True))
    )
    op.add_column(
        "meeting",
        sa.Column("grace_period_minutes", sa.Integer, server_default=sa.text("15")),
    )


def downgrade() -> None:
    op.drop_column("meeting", "grace_period_minutes")
    op.drop_column("meeting", "last_participant_left_at")
@@ -32,6 +32,7 @@ dependencies = [
    "redis>=5.0.1",
    "python-jose[cryptography]>=3.3.0",
    "python-multipart>=0.0.6",
    "faster-whisper>=0.10.0",
    "transformers>=4.36.2",
    "jsonschema>=4.23.0",
    "openai>=1.59.7",
@@ -40,6 +41,7 @@ dependencies = [
    "llama-index-llms-openai-like>=0.4.0",
    "pytest-env>=1.1.5",
    "webvtt-py>=0.5.0",
    "icalendar>=6.0.0",
]

[dependency-groups]
@@ -56,7 +58,6 @@ tests = [
    "httpx-ws>=0.4.1",
    "pytest-httpx>=0.23.1",
    "pytest-celery>=0.0.0",
    "pytest-recording>=0.13.4",
    "pytest-docker>=3.2.3",
    "asgi-lifespan>=2.1.0",
]
@@ -67,15 +68,6 @@ evaluation = [
    "tqdm>=4.66.0",
    "pydantic>=2.1.1",
]
local = [
    "pyannote-audio>=3.3.2",
    "faster-whisper>=0.10.0",
]
silero-vad = [
    "silero-vad>=5.1.2",
    "torch>=2.8.0",
    "torchaudio>=2.8.0",
]

[tool.uv]
default-groups = [
@@ -83,21 +75,6 @@ default-groups = [
    "tests",
    "aws",
    "evaluation",
    "local",
    "silero-vad"
]

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

[tool.uv.sources]
torch = [
    { index = "pytorch-cpu" },
]
torchaudio = [
    { index = "pytorch-cpu" },
]

[build-system]
@@ -118,9 +95,6 @@ DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_t
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
testpaths = ["tests"]
asyncio_mode = "auto"
markers = [
    "gpu_modal: mark test to run only with GPU Modal endpoints (deselect with '-m \"not gpu_modal\"')",
]

[tool.ruff.lint]
select = [

@@ -24,6 +24,7 @@ def get_database() -> databases.Database:

# import models
import reflector.db.calendar_events  # noqa
import reflector.db.meetings  # noqa
import reflector.db.recordings  # noqa
import reflector.db.rooms  # noqa
193 server/reflector/db/calendar_events.py Normal file
@@ -0,0 +1,193 @@
from datetime import datetime, timezone
from typing import Any

import sqlalchemy as sa
from pydantic import BaseModel, Field
from sqlalchemy.dialects.postgresql import JSONB

from reflector.db import get_database, metadata
from reflector.utils import generate_uuid4

calendar_events = sa.Table(
    "calendar_event",
    metadata,
    sa.Column("id", sa.String, primary_key=True),
    sa.Column(
        "room_id",
        sa.String,
        sa.ForeignKey("room.id", ondelete="CASCADE"),
        nullable=False,
    ),
    sa.Column("ics_uid", sa.Text, nullable=False),
    sa.Column("title", sa.Text),
    sa.Column("description", sa.Text),
    sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
    sa.Column("end_time", sa.DateTime(timezone=True), nullable=False),
    sa.Column("attendees", JSONB),
    sa.Column("location", sa.Text),
    sa.Column("ics_raw_data", sa.Text),
    sa.Column("last_synced", sa.DateTime(timezone=True), nullable=False),
    sa.Column("is_deleted", sa.Boolean, nullable=False, server_default=sa.false()),
    sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
    sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
    sa.UniqueConstraint("room_id", "ics_uid", name="uq_room_calendar_event"),
    sa.Index("idx_calendar_event_room_start", "room_id", "start_time"),
    sa.Index(
        "idx_calendar_event_deleted",
        "is_deleted",
        postgresql_where=sa.text("NOT is_deleted"),
    ),
)


class CalendarEvent(BaseModel):
    id: str = Field(default_factory=generate_uuid4)
    room_id: str
    ics_uid: str
    title: str | None = None
    description: str | None = None
    start_time: datetime
    end_time: datetime
    attendees: list[dict[str, Any]] | None = None
    location: str | None = None
    ics_raw_data: str | None = None
    last_synced: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    is_deleted: bool = False
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))


class CalendarEventController:
    async def get_by_room(
        self,
        room_id: str,
        include_deleted: bool = False,
        start_after: datetime | None = None,
        end_before: datetime | None = None,
    ) -> list[CalendarEvent]:
        """Get calendar events for a room."""
        query = calendar_events.select().where(calendar_events.c.room_id == room_id)

        if not include_deleted:
            query = query.where(calendar_events.c.is_deleted == False)

        if start_after:
            query = query.where(calendar_events.c.start_time >= start_after)

        if end_before:
            query = query.where(calendar_events.c.end_time <= end_before)

        query = query.order_by(calendar_events.c.start_time.asc())

        results = await get_database().fetch_all(query)
        return [CalendarEvent(**result) for result in results]

    async def get_upcoming(
        self, room_id: str, minutes_ahead: int = 30
    ) -> list[CalendarEvent]:
        """Get upcoming events for a room within the specified minutes."""
        now = datetime.now(timezone.utc)
        future_time = now + timedelta(minutes=minutes_ahead)

        query = (
            calendar_events.select()
            .where(
                sa.and_(
                    calendar_events.c.room_id == room_id,
                    calendar_events.c.is_deleted == False,
                    calendar_events.c.start_time >= now,
                    calendar_events.c.start_time <= future_time,
                )
            )
            .order_by(calendar_events.c.start_time.asc())
        )

        results = await get_database().fetch_all(query)
        return [CalendarEvent(**result) for result in results]

    async def get_by_ics_uid(self, room_id: str, ics_uid: str) -> CalendarEvent | None:
        """Get a calendar event by its ICS UID."""
        query = calendar_events.select().where(
            sa.and_(
                calendar_events.c.room_id == room_id,
                calendar_events.c.ics_uid == ics_uid,
            )
        )
        result = await get_database().fetch_one(query)
        return CalendarEvent(**result) if result else None

    async def upsert(self, event: CalendarEvent) -> CalendarEvent:
        """Create or update a calendar event."""
        existing = await self.get_by_ics_uid(event.room_id, event.ics_uid)

        if existing:
            # Update existing event
            event.id = existing.id
            event.created_at = existing.created_at
            event.updated_at = datetime.now(timezone.utc)

            query = (
                calendar_events.update()
                .where(calendar_events.c.id == existing.id)
                .values(**event.model_dump())
            )
        else:
            # Insert new event
            query = calendar_events.insert().values(**event.model_dump())

        await get_database().execute(query)
        return event

    async def soft_delete_missing(
        self, room_id: str, current_ics_uids: list[str]
    ) -> int:
        """Soft delete future events that are no longer in the calendar."""
        now = datetime.now(timezone.utc)

        # First, get the IDs of events to delete
        select_query = calendar_events.select().where(
            sa.and_(
                calendar_events.c.room_id == room_id,
                calendar_events.c.start_time > now,
                calendar_events.c.is_deleted == False,
                calendar_events.c.ics_uid.notin_(current_ics_uids)
                if current_ics_uids
                else True,
            )
        )

        to_delete = await get_database().fetch_all(select_query)
        delete_count = len(to_delete)

        if delete_count > 0:
            # Now update them
            update_query = (
                calendar_events.update()
                .where(
                    sa.and_(
                        calendar_events.c.room_id == room_id,
                        calendar_events.c.start_time > now,
                        calendar_events.c.is_deleted == False,
                        calendar_events.c.ics_uid.notin_(current_ics_uids)
                        if current_ics_uids
                        else True,
                    )
                )
                .values(is_deleted=True, updated_at=now)
            )

            await get_database().execute(update_query)

        return delete_count

    async def delete_by_room(self, room_id: str) -> int:
        """Hard delete all events for a room (used when room is deleted)."""
        query = calendar_events.delete().where(calendar_events.c.room_id == room_id)
        result = await get_database().execute(query)
        return result.rowcount


# Add missing import
from datetime import timedelta

calendar_events_controller = CalendarEventController()
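A short usage sketch for the controller above; field values are illustrative, and the ICS sync task that would normally call this is not part of the diff:

    from datetime import datetime, timedelta, timezone

    async def sync_one_event() -> list[CalendarEvent]:
        # Upsert one parsed ICS event, then list what starts in the next half hour.
        event = CalendarEvent(
            room_id="room-123",
            ics_uid="event-uid@example.com",
            title="Weekly sync",
            start_time=datetime.now(timezone.utc) + timedelta(minutes=10),
            end_time=datetime.now(timezone.utc) + timedelta(minutes=40),
        )
        await calendar_events_controller.upsert(event)
        return await calendar_events_controller.get_upcoming("room-123", minutes_ahead=30)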
@@ -1,9 +1,10 @@
from datetime import datetime
from typing import Literal
from typing import Any, Literal

import sqlalchemy as sa
from fastapi import HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.dialects.postgresql import JSONB

from reflector.db import get_database, metadata
from reflector.db.rooms import Room
@@ -41,13 +42,16 @@ meetings = sa.Table(
        nullable=False,
        server_default=sa.true(),
    ),
    sa.Index("idx_meeting_room_id", "room_id"),
    sa.Index(
        "idx_one_active_meeting_per_room",
        "room_id",
        unique=True,
        postgresql_where=sa.text("is_active = true"),
    sa.Column(
        "calendar_event_id",
        sa.String,
        sa.ForeignKey("calendar_event.id", ondelete="SET NULL"),
    ),
    sa.Column("calendar_metadata", JSONB),
    sa.Column("last_participant_left_at", sa.DateTime(timezone=True)),
    sa.Column("grace_period_minutes", sa.Integer, server_default=sa.text("15")),
    sa.Index("idx_meeting_room_id", "room_id"),
    sa.Index("idx_meeting_calendar_event", "calendar_event_id"),
)

meeting_consent = sa.Table(
@@ -85,6 +89,11 @@ class Meeting(BaseModel):
        "none", "prompt", "automatic", "automatic-2nd-participant"
    ] = "automatic-2nd-participant"
    num_clients: int = 0
    is_active: bool = True
    calendar_event_id: str | None = None
    calendar_metadata: dict[str, Any] | None = None
    last_participant_left_at: datetime | None = None
    grace_period_minutes: int = 15


class MeetingController:
@@ -98,6 +107,8 @@ class MeetingController:
        end_date: datetime,
        user_id: str,
        room: Room,
        calendar_event_id: str | None = None,
        calendar_metadata: dict[str, Any] | None = None,
    ):
        """
        Create a new meeting
@@ -115,6 +126,8 @@ class MeetingController:
            room_mode=room.room_mode,
            recording_type=room.recording_type,
            recording_trigger=room.recording_trigger,
            calendar_event_id=calendar_event_id,
            calendar_metadata=calendar_metadata,
        )
        query = meetings.insert().values(**meeting.model_dump())
        await get_database().execute(query)
@@ -144,6 +157,7 @@ class MeetingController:
    async def get_active(self, room: Room, current_time: datetime) -> Meeting:
        """
        Get latest active meeting for a room.
        For backward compatibility, returns the most recent active meeting.
        """
        end_date = getattr(meetings.c, "end_date")
        query = (
@@ -163,6 +177,47 @@ class MeetingController:

        return Meeting(**result)

    async def get_all_active_for_room(
        self, room: Room, current_time: datetime
    ) -> list[Meeting]:
        """
        Get all active meetings for a room.
        This supports multiple concurrent meetings per room.
        """
        end_date = getattr(meetings.c, "end_date")
        query = (
            meetings.select()
            .where(
                sa.and_(
                    meetings.c.room_id == room.id,
                    meetings.c.end_date > current_time,
                    meetings.c.is_active,
                )
            )
            .order_by(end_date.desc())
        )
        results = await get_database().fetch_all(query)
        return [Meeting(**result) for result in results]

    async def get_active_by_calendar_event(
        self, room: Room, calendar_event_id: str, current_time: datetime
    ) -> Meeting | None:
        """
        Get active meeting for a specific calendar event.
        """
        query = meetings.select().where(
            sa.and_(
                meetings.c.room_id == room.id,
                meetings.c.calendar_event_id == calendar_event_id,
                meetings.c.end_date > current_time,
                meetings.c.is_active,
            )
        )
        result = await get_database().fetch_one(query)
        if not result:
            return None
        return Meeting(**result)

    async def get_by_id(self, meeting_id: str, **kwargs) -> Meeting | None:
        """
        Get a meeting by id
@@ -190,6 +245,15 @@ class MeetingController:

        return meeting

    async def get_by_calendar_event(self, calendar_event_id: str) -> Meeting | None:
        query = meetings.select().where(
            meetings.c.calendar_event_id == calendar_event_id
        )
        result = await get_database().fetch_one(query)
        if not result:
            return None
        return Meeting(**result)

    async def update_meeting(self, meeting_id: str, **kwargs):
        query = meetings.update().where(meetings.c.id == meeting_id).values(**kwargs)
        await get_database().execute(query)
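A rough sketch of how the new lookups might be combined by calling code (hypothetical; the caller is not shown in this diff): prefer the meeting tied to a calendar event, and fall back to the legacy latest-active lookup:

    async def resolve_meeting(controller: MeetingController, room, calendar_event_id, now):
        meeting = await controller.get_active_by_calendar_event(room, calendar_event_id, now)
        if meeting is None:
            meeting = await controller.get_active(room, now)
        return meeting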
@@ -40,7 +40,15 @@ rooms = sqlalchemy.Table(
    sqlalchemy.Column(
        "is_shared", sqlalchemy.Boolean, nullable=False, server_default=false()
    ),
    sqlalchemy.Column("ics_url", sqlalchemy.Text),
    sqlalchemy.Column("ics_fetch_interval", sqlalchemy.Integer, server_default="300"),
    sqlalchemy.Column(
        "ics_enabled", sqlalchemy.Boolean, nullable=False, server_default=false()
    ),
    sqlalchemy.Column("ics_last_sync", sqlalchemy.DateTime(timezone=True)),
    sqlalchemy.Column("ics_last_etag", sqlalchemy.Text),
    sqlalchemy.Index("idx_room_is_shared", "is_shared"),
    sqlalchemy.Index("idx_room_ics_enabled", "ics_enabled"),
)


@@ -59,6 +67,11 @@ class Room(BaseModel):
        "none", "prompt", "automatic", "automatic-2nd-participant"
    ] = "automatic-2nd-participant"
    is_shared: bool = False
    ics_url: str | None = None
    ics_fetch_interval: int = 300
    ics_enabled: bool = False
    ics_last_sync: datetime | None = None
    ics_last_etag: str | None = None


class RoomController:
@@ -107,6 +120,9 @@ class RoomController:
        recording_type: str,
        recording_trigger: str,
        is_shared: bool,
        ics_url: str | None = None,
        ics_fetch_interval: int = 300,
        ics_enabled: bool = False,
    ):
        """
        Add a new room
@@ -122,6 +138,9 @@ class RoomController:
            recording_type=recording_type,
            recording_trigger=recording_trigger,
            is_shared=is_shared,
            ics_url=ics_url,
            ics_fetch_interval=ics_fetch_interval,
            ics_enabled=ics_enabled,
        )
        query = rooms.insert().values(**room.model_dump())
        try:
@@ -1,37 +1,24 @@
"""Search functionality for transcripts and other entities."""

import itertools
from dataclasses import dataclass
from datetime import datetime
from io import StringIO
from typing import Annotated, Any, Dict, Iterator
from typing import Annotated, Any, Dict

import sqlalchemy
import webvtt
from fastapi import HTTPException
from pydantic import (
    BaseModel,
    Field,
    NonNegativeFloat,
    NonNegativeInt,
    ValidationError,
    constr,
    field_serializer,
)
from pydantic import BaseModel, Field, constr, field_serializer

from reflector.db import get_database
from reflector.db.rooms import rooms
from reflector.db.transcripts import SourceKind, transcripts
from reflector.db.utils import is_postgresql
from reflector.logger import logger

DEFAULT_SEARCH_LIMIT = 20
SNIPPET_CONTEXT_LENGTH = 50  # Characters before/after match to include
DEFAULT_SNIPPET_MAX_LENGTH = NonNegativeInt(150)
DEFAULT_MAX_SNIPPETS = NonNegativeInt(3)
LONG_SUMMARY_MAX_SNIPPETS = 2
DEFAULT_SNIPPET_MAX_LENGTH = 150
DEFAULT_MAX_SNIPPETS = 3

SearchQueryBase = constr(min_length=0, strip_whitespace=True)
SearchQueryBase = constr(min_length=1, strip_whitespace=True)
SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
SearchOffsetBase = Annotated[int, Field(ge=0)]
SearchTotalBase = Annotated[int, Field(ge=0)]
@@ -45,82 +32,6 @@ SearchTotal = Annotated[
    SearchTotalBase, Field(description="Total number of search results")
]

WEBVTT_SPEC_HEADER = "WEBVTT"

WebVTTContent = Annotated[
    str,
    Field(min_length=len(WEBVTT_SPEC_HEADER), description="WebVTT content"),
]


class WebVTTProcessor:
    """Stateless processor for WebVTT content operations."""

    @staticmethod
    def parse(raw_content: str) -> WebVTTContent:
        """Parse WebVTT content and return it as a string."""
        if not raw_content.startswith(WEBVTT_SPEC_HEADER):
            raise ValueError(f"Invalid WebVTT content, no header {WEBVTT_SPEC_HEADER}")
        return raw_content

    @staticmethod
    def extract_text(webvtt_content: WebVTTContent) -> str:
        """Extract plain text from WebVTT content using webvtt library."""
        try:
            buffer = StringIO(webvtt_content)
            vtt = webvtt.read_buffer(buffer)
            return " ".join(caption.text for caption in vtt if caption.text)
        except webvtt.errors.MalformedFileError as e:
            logger.warning(f"Malformed WebVTT content: {e}")
            return ""
        except (UnicodeDecodeError, ValueError) as e:
            logger.warning(f"Failed to decode WebVTT content: {e}")
            return ""
        except AttributeError as e:
            logger.error(
                f"WebVTT parsing error - unexpected format: {e}", exc_info=True
            )
            return ""
        except Exception as e:
            logger.error(f"Unexpected error parsing WebVTT: {e}", exc_info=True)
            return ""

    @staticmethod
    def generate_snippets(
        webvtt_content: WebVTTContent,
        query: str,
        max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
    ) -> list[str]:
        """Generate snippets from WebVTT content."""
        return SnippetGenerator.generate(
            WebVTTProcessor.extract_text(webvtt_content),
            query,
            max_snippets=max_snippets,
        )


@dataclass(frozen=True)
class SnippetCandidate:
    """Represents a candidate snippet with its position."""

    _text: str
    start: NonNegativeInt
    _original_text_length: int

    @property
    def end(self) -> NonNegativeInt:
        """Calculate end position from start and raw text length."""
        return self.start + len(self._text)

    def text(self) -> str:
        """Get display text with ellipses added if needed."""
        result = self._text.strip()
        if self.start > 0:
            result = "..." + result
        if self.end < self._original_text_length:
            result = result + "..."
        return result
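# Quick illustration (added commentary, not in the original file): a mid-text
# candidate gains ellipses on both sides.
#     c = SnippetCandidate(_text="quick brown fox", start=4, _original_text_length=30)
#     c.text()  # -> "...quick brown fox..."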

class SearchParameters(BaseModel):
    """Validated search parameters for full-text search."""
@@ -130,7 +41,6 @@ class SearchParameters(BaseModel):
    offset: SearchOffset = 0
    user_id: str | None = None
    room_id: str | None = None
    source_kind: SourceKind | None = None


class SearchResultDB(BaseModel):
@@ -154,18 +64,13 @@ class SearchResult(BaseModel):
    title: str | None = None
    user_id: str | None = None
    room_id: str | None = None
    room_name: str | None = None
    source_kind: SourceKind
    created_at: datetime
    status: str = Field(..., min_length=1)
    rank: float = Field(..., ge=0, le=1)
    duration: NonNegativeFloat | None = Field(..., description="Duration in seconds")
    duration: float | None = Field(..., ge=0, description="Duration in seconds")
    search_snippets: list[str] = Field(
        description="Text snippets around search matches"
    )
    total_match_count: NonNegativeInt = Field(
        default=0, description="Total number of matches found in the transcript"
    )

    @field_serializer("created_at", when_used="json")
    def serialize_datetime(self, dt: datetime) -> str:
@@ -174,153 +79,84 @@ class SearchResult(BaseModel):
        return dt.isoformat()


class SnippetGenerator:
    """Stateless generator for text snippets and match operations."""

    @staticmethod
    def find_all_matches(text: str, query: str) -> Iterator[int]:
        """Generate all match positions for a query in text."""
        if not text:
            logger.warning("Empty text for search query in find_all_matches")
            return
        if not query:
            logger.warning("Empty query for search text in find_all_matches")
            return

        text_lower = text.lower()
        query_lower = query.lower()
        start = 0
        prev_start = start
        while (pos := text_lower.find(query_lower, start)) != -1:
            yield pos
            start = pos + len(query_lower)
            if start <= prev_start:
                raise ValueError("panic! find_all_matches is not incremental")
            prev_start = start
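    # Quick check (added commentary, not in the original file): matches are
    # case-insensitive and non-overlapping, scanning past each hit:
    #     list(SnippetGenerator.find_all_matches("Abra cadabra", "ab"))  # -> [0, 8]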

    @staticmethod
    def count_matches(text: str, query: str) -> NonNegativeInt:
        """Count total number of matches for a query in text."""
        ZERO = NonNegativeInt(0)
        if not text:
            logger.warning("Empty text for search query in count_matches")
            return ZERO
        if not query:
            logger.warning("Empty query for search text in count_matches")
            return ZERO
        return NonNegativeInt(
            sum(1 for _ in SnippetGenerator.find_all_matches(text, query))
        )

    @staticmethod
    def create_snippet(
        text: str, match_pos: int, max_length: int = DEFAULT_SNIPPET_MAX_LENGTH
    ) -> SnippetCandidate:
        """Create a snippet from a match position."""
        snippet_start = NonNegativeInt(max(0, match_pos - SNIPPET_CONTEXT_LENGTH))
        snippet_end = min(len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH)

        snippet_text = text[snippet_start:snippet_end]

        return SnippetCandidate(
            _text=snippet_text, start=snippet_start, _original_text_length=len(text)
        )

    @staticmethod
    def filter_non_overlapping(
        candidates: Iterator[SnippetCandidate],
    ) -> Iterator[str]:
        """Filter out overlapping snippets and return only display text."""
        last_end = 0
        for candidate in candidates:
            display_text = candidate.text()
            # Overlapping snippets simply don't get included. This simplistic
            # logic is fine; users already have their search results, so they
            # are unlikely to care about the occasional dropped snippet.
            if candidate.start >= last_end and display_text:
                yield display_text
                last_end = candidate.end

    @staticmethod
    def generate(
        text: str,
        query: str,
        max_length: NonNegativeInt = DEFAULT_SNIPPET_MAX_LENGTH,
        max_snippets: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
    ) -> list[str]:
        """Generate snippets from text."""
        if not text or not query:
            logger.warning("Empty text or query for generate_snippets")
            return []

        candidates = (
            SnippetGenerator.create_snippet(text, pos, max_length)
            for pos in SnippetGenerator.find_all_matches(text, query)
        )
        filtered = SnippetGenerator.filter_non_overlapping(candidates)
        snippets = list(itertools.islice(filtered, max_snippets))

        # Fall back to searching for the first word if there are no full matches.
        # Another deliberate simplification: proper snippet generation is quite
        # complicated and tied to the database logic, so this approximation is
        # used instead.
        if not snippets and " " in query:
            first_word = query.split()[0]
            return SnippetGenerator.generate(text, first_word, max_length, max_snippets)

        return snippets
|
||||
|
||||
@staticmethod
|
||||
def from_summary(
|
||||
summary: str,
|
||||
query: str,
|
||||
max_snippets: NonNegativeInt = LONG_SUMMARY_MAX_SNIPPETS,
|
||||
) -> list[str]:
|
||||
"""Generate snippets from summary text."""
|
||||
return SnippetGenerator.generate(summary, query, max_snippets=max_snippets)
|
||||
|
||||
@staticmethod
|
||||
def combine_sources(
|
||||
summary: str | None,
|
||||
webvtt: WebVTTContent | None,
|
||||
query: str,
|
||||
max_total: NonNegativeInt = DEFAULT_MAX_SNIPPETS,
|
||||
) -> tuple[list[str], NonNegativeInt]:
|
||||
"""Combine snippets from multiple sources and return total match count.
|
||||
|
||||
Returns (snippets, total_match_count) tuple.
|
||||
|
||||
snippets can be empty for real in case of e.g. title match
|
||||
"""
|
||||
webvtt_matches = 0
|
||||
summary_matches = 0
|
||||
|
||||
if webvtt:
|
||||
webvtt_text = WebVTTProcessor.extract_text(webvtt)
|
||||
webvtt_matches = SnippetGenerator.count_matches(webvtt_text, query)
|
||||
|
||||
if summary:
|
||||
summary_matches = SnippetGenerator.count_matches(summary, query)
|
||||
|
||||
total_matches = NonNegativeInt(webvtt_matches + summary_matches)
|
||||
|
||||
summary_snippets = (
|
||||
SnippetGenerator.from_summary(summary, query) if summary else []
|
||||
)
|
||||
|
||||
if len(summary_snippets) >= max_total:
|
||||
return summary_snippets[:max_total], total_matches
|
||||
|
||||
remaining = max_total - len(summary_snippets)
|
||||
webvtt_snippets = (
|
||||
WebVTTProcessor.generate_snippets(webvtt, query, remaining)
|
||||
if webvtt
|
||||
else []
|
||||
)
|
||||
|
||||
return summary_snippets + webvtt_snippets, total_matches
|
||||
|
||||
|
||||
class SearchController:
|
||||
"""Controller for search operations across different entities."""
|
||||
|
||||
@staticmethod
|
||||
def _extract_webvtt_text(webvtt_content: str) -> str:
|
||||
"""Extract plain text from WebVTT content using webvtt library."""
|
||||
if not webvtt_content:
|
||||
return ""
|
||||
|
||||
try:
|
||||
buffer = StringIO(webvtt_content)
|
||||
vtt = webvtt.read_buffer(buffer)
|
||||
return " ".join(caption.text for caption in vtt if caption.text)
|
||||
except (webvtt.errors.MalformedFileError, UnicodeDecodeError, ValueError) as e:
|
||||
logger.warning(f"Failed to parse WebVTT content: {e}", exc_info=e)
|
||||
return ""
|
||||
except AttributeError as e:
|
||||
logger.warning(f"WebVTT parsing error - unexpected format: {e}", exc_info=e)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _generate_snippets(
|
||||
text: str,
|
||||
q: SearchQuery,
|
||||
max_length: int = DEFAULT_SNIPPET_MAX_LENGTH,
|
||||
max_snippets: int = DEFAULT_MAX_SNIPPETS,
|
||||
) -> list[str]:
|
||||
"""Generate multiple snippets around all occurrences of search term."""
|
||||
if not text or not q:
|
||||
return []
|
||||
|
||||
snippets = []
|
||||
lower_text = text.lower()
|
||||
search_lower = q.lower()
|
||||
|
||||
last_snippet_end = 0
|
||||
start_pos = 0
|
||||
|
||||
while len(snippets) < max_snippets:
|
||||
match_pos = lower_text.find(search_lower, start_pos)
|
||||
|
||||
if match_pos == -1:
|
||||
if not snippets and search_lower.split():
|
||||
first_word = search_lower.split()[0]
|
||||
match_pos = lower_text.find(first_word, start_pos)
|
||||
if match_pos == -1:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
snippet_start = max(0, match_pos - SNIPPET_CONTEXT_LENGTH)
|
||||
snippet_end = min(
|
||||
len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH
|
||||
)
|
||||
|
||||
if snippet_start < last_snippet_end:
|
||||
start_pos = match_pos + len(search_lower)
|
||||
continue
|
||||
|
||||
snippet = text[snippet_start:snippet_end]
|
||||
|
||||
if snippet_start > 0:
|
||||
snippet = "..." + snippet
|
||||
if snippet_end < len(text):
|
||||
snippet = snippet + "..."
|
||||
|
||||
snippet = snippet.strip()
|
||||
|
||||
if snippet:
|
||||
snippets.append(snippet)
|
||||
last_snippet_end = snippet_end
|
||||
|
||||
start_pos = match_pos + len(search_lower)
|
||||
if start_pos >= len(text):
|
||||
break
|
||||
|
||||
return snippets
|
||||

    @classmethod
    async def search_transcripts(
        cls, params: SearchParameters
@@ -336,70 +172,39 @@ class SearchController:
            )
            return [], 0

        base_columns = [
            transcripts.c.id,
            transcripts.c.title,
            transcripts.c.created_at,
            transcripts.c.duration,
            transcripts.c.status,
            transcripts.c.user_id,
            transcripts.c.room_id,
            transcripts.c.source_kind,
            transcripts.c.webvtt,
            transcripts.c.long_summary,
            sqlalchemy.case(
                (
                    transcripts.c.room_id.isnot(None) & rooms.c.id.is_(None),
                    "Deleted Room",
                ),
                else_=rooms.c.name,
            ).label("room_name"),
        ]

        if params.query_text:
            search_query = sqlalchemy.func.websearch_to_tsquery(
                "english", params.query_text
            )
            rank_column = sqlalchemy.func.ts_rank(
                transcripts.c.search_vector_en,
                search_query,
                32,  # normalization flag: rank/(rank+1) for 0-1 range
            ).label("rank")
        else:
            rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank")

        columns = base_columns + [rank_column]
        base_query = sqlalchemy.select(columns).select_from(
            transcripts.join(rooms, transcripts.c.room_id == rooms.c.id, isouter=True)
        search_query = sqlalchemy.func.websearch_to_tsquery(
            "english", params.query_text
        )

        if params.query_text:
            base_query = base_query.where(
                transcripts.c.search_vector_en.op("@@")(search_query)
            )
        base_query = sqlalchemy.select(
            [
                transcripts.c.id,
                transcripts.c.title,
                transcripts.c.created_at,
                transcripts.c.duration,
                transcripts.c.status,
                transcripts.c.user_id,
                transcripts.c.room_id,
                transcripts.c.source_kind,
                transcripts.c.webvtt,
                sqlalchemy.func.ts_rank(
                    transcripts.c.search_vector_en,
                    search_query,
                    32,  # normalization flag: rank/(rank+1) for 0-1 range
                ).label("rank"),
            ]
        ).where(transcripts.c.search_vector_en.op("@@")(search_query))

        if params.user_id:
            base_query = base_query.where(
                sqlalchemy.or_(
                    transcripts.c.user_id == params.user_id, rooms.c.is_shared
                )
            )
        else:
            base_query = base_query.where(rooms.c.is_shared)
        base_query = base_query.where(transcripts.c.user_id == params.user_id)
        if params.room_id:
            base_query = base_query.where(transcripts.c.room_id == params.room_id)
        if params.source_kind:
            base_query = base_query.where(
                transcripts.c.source_kind == params.source_kind
            )

        if params.query_text:
            order_by = sqlalchemy.desc(sqlalchemy.text("rank"))
        else:
            order_by = sqlalchemy.desc(transcripts.c.created_at)

        query = base_query.order_by(order_by).limit(params.limit).offset(params.offset)

        query = (
            base_query.order_by(sqlalchemy.desc(sqlalchemy.text("rank")))
            .limit(params.limit)
            .offset(params.offset)
        )
        rs = await get_database().fetch_all(query)

        count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
@@ -409,40 +214,18 @@ class SearchController:

        def _process_result(r) -> SearchResult:
            r_dict: Dict[str, Any] = dict(r)
            webvtt_raw: str | None = r_dict.pop("webvtt", None)
            if webvtt_raw:
                webvtt = WebVTTProcessor.parse(webvtt_raw)
            else:
                webvtt = None
            long_summary: str | None = r_dict.pop("long_summary", None)
            room_name: str | None = r_dict.pop("room_name", None)
            webvtt: str | None = r_dict.pop("webvtt", None)
            db_result = SearchResultDB.model_validate(r_dict)

            snippets, total_match_count = SnippetGenerator.combine_sources(
                long_summary, webvtt, params.query_text, DEFAULT_MAX_SNIPPETS
            )
            snippets = []
            if webvtt:
                plain_text = cls._extract_webvtt_text(webvtt)
                snippets = cls._generate_snippets(plain_text, params.query_text)

            return SearchResult(
                **db_result.model_dump(),
                room_name=room_name,
                search_snippets=snippets,
                total_match_count=total_match_count,
            )

        try:
            results = [_process_result(r) for r in rs]
        except ValidationError as e:
            logger.error(f"Invalid search result data: {e}", exc_info=True)
            raise HTTPException(
                status_code=500, detail="Internal search result data consistency error"
            )
        except Exception as e:
            logger.error(f"Error processing search results: {e}", exc_info=True)
            raise
            return SearchResult(**db_result.model_dump(), search_snippets=snippets)

        results = [_process_result(r) for r in rs]
        return results, total


search_controller = SearchController()
webvtt_processor = WebVTTProcessor()
snippet_generator = SnippetGenerator()

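For context on the ranking used in both versions of the query above: websearch_to_tsquery parses a user-style query string, the @@ operator tests it against the stored tsvector, and ts_rank's normalization flag 32 maps a raw rank r to r/(r+1) so scores stay in [0, 1). The normalization itself is just arithmetic; a quick check:

# ts_rank(..., 32) normalization: r -> r / (r + 1), bounded to [0, 1).
raw_ranks = [0.05, 0.5, 4.0]
normalized = [r / (r + 1) for r in raw_ranks]
print(normalized)  # [0.0476..., 0.3333..., 0.8]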
@@ -88,8 +88,6 @@ transcripts = sqlalchemy.Table(
    sqlalchemy.Index("idx_transcript_created_at", "created_at"),
    sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
    sqlalchemy.Index("idx_transcript_room_id", "room_id"),
    sqlalchemy.Index("idx_transcript_source_kind", "source_kind"),
    sqlalchemy.Index("idx_transcript_room_id_created_at", "room_id", "created_at"),
)

# Add PostgreSQL-specific full-text search column
@@ -101,8 +99,7 @@ if is_postgresql():
        TSVECTOR,
        sqlalchemy.Computed(
            "setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
            "setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || "
            "setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')",
            "setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')",
            persisted=True,
        ),
    )

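Reading the hunk above as deletions first, the computed tsvector loses the long_summary component: the old definition weighted title 'A', long_summary 'B', webvtt 'C', while the new one keeps title 'A' and moves webvtt up to 'B'. With PostgreSQL's default ts_rank weight table (D=0.1, C=0.2, B=0.4, A=1.0), the relative contribution of a hit per field shifts roughly like this (a sketch, not measured numbers):

# Default per-label weights used by ts_rank unless overridden.
weights = {"A": 1.0, "B": 0.4, "C": 0.2, "D": 0.1}
old_contribution = {"title": weights["A"], "long_summary": weights["B"], "webvtt": weights["C"]}
new_contribution = {"title": weights["A"], "webvtt": weights["B"]}
print(old_contribution, new_contribution)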
@@ -1,375 +0,0 @@
"""
File-based processing pipeline
==============================

Optimized pipeline for processing complete audio/video files.
Uses parallel processing for transcription, diarization, and waveform generation.
"""

import asyncio
from pathlib import Path

import av
import structlog
from celery import shared_task

from reflector.db.transcripts import (
    Transcript,
    transcripts_controller,
)
from reflector.logger import logger
from reflector.pipelines.main_live_pipeline import PipelineMainBase, asynctask
from reflector.processors import (
    AudioFileWriterProcessor,
    TranscriptFinalSummaryProcessor,
    TranscriptFinalTitleProcessor,
    TranscriptTopicDetectorProcessor,
)
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
from reflector.processors.file_diarization import FileDiarizationInput
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
from reflector.processors.file_transcript import FileTranscriptInput
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
from reflector.processors.transcript_diarization_assembler import (
    TranscriptDiarizationAssemblerInput,
    TranscriptDiarizationAssemblerProcessor,
)
from reflector.processors.types import (
    DiarizationSegment,
    TitleSummary,
)
from reflector.processors.types import (
    Transcript as TranscriptType,
)
from reflector.settings import settings
from reflector.storage import get_transcripts_storage


class EmptyPipeline:
    """Empty pipeline for processors that need a pipeline reference"""

    def __init__(self, logger: structlog.BoundLogger):
        self.logger = logger

    def get_pref(self, k, d=None):
        return d

    async def emit(self, event):
        pass


class PipelineMainFile(PipelineMainBase):
    """
    Optimized file processing pipeline.
    Processes complete audio/video files with parallel execution.
    """

    logger: structlog.BoundLogger = None
    empty_pipeline = None

    def __init__(self, transcript_id: str):
        super().__init__(transcript_id=transcript_id)
        self.logger = logger.bind(transcript_id=self.transcript_id)
        self.empty_pipeline = EmptyPipeline(logger=self.logger)

    def _handle_gather_exceptions(self, results: list, operation: str) -> None:
        """Handle exceptions from asyncio.gather with return_exceptions=True"""
        for i, result in enumerate(results):
            if not isinstance(result, Exception):
                continue
            self.logger.error(
                f"Error in {operation} (task {i}): {result}",
                transcript_id=self.transcript_id,
                exc_info=result,
            )

    async def process(self, file_path: Path):
        """Main entry point for file processing"""
        self.logger.info(f"Starting file pipeline for {file_path}")

        transcript = await self.get_transcript()

        # Extract audio and write to transcript location
        audio_path = await self.extract_and_write_audio(file_path, transcript)

        # Upload for processing
        audio_url = await self.upload_audio(audio_path, transcript)

        # Run parallel processing
        await self.run_parallel_processing(
            audio_path,
            audio_url,
            transcript.source_language,
            transcript.target_language,
        )

        self.logger.info("File pipeline complete")

    async def extract_and_write_audio(
        self, file_path: Path, transcript: Transcript
    ) -> Path:
        """Extract audio from video if needed and write to transcript location as MP3"""
        self.logger.info(f"Processing audio file: {file_path}")

        # Check if it's already audio-only
        container = av.open(str(file_path))
        has_video = len(container.streams.video) > 0
        container.close()

        # Use AudioFileWriterProcessor to write MP3 to transcript location
        mp3_writer = AudioFileWriterProcessor(
            path=transcript.audio_mp3_filename,
            on_duration=self.on_duration,
        )

        # Process audio frames and write to transcript location
        input_container = av.open(str(file_path))
        for frame in input_container.decode(audio=0):
            await mp3_writer.push(frame)

        await mp3_writer.flush()
        input_container.close()

        if has_video:
            self.logger.info(
                f"Extracted audio from video and saved to {transcript.audio_mp3_filename}"
            )
        else:
            self.logger.info(
                f"Converted audio file and saved to {transcript.audio_mp3_filename}"
            )

        return transcript.audio_mp3_filename

    async def upload_audio(self, audio_path: Path, transcript: Transcript) -> str:
        """Upload audio to storage for processing"""
        storage = get_transcripts_storage()

        if not storage:
            raise Exception(
                "Storage backend required for file processing. Configure TRANSCRIPT_STORAGE_* settings."
            )

        self.logger.info("Uploading audio to storage")

        with open(audio_path, "rb") as f:
            audio_data = f.read()

        storage_path = f"file_pipeline/{transcript.id}/audio.mp3"
        await storage.put_file(storage_path, audio_data)

        audio_url = await storage.get_file_url(storage_path)

        self.logger.info(f"Audio uploaded to {audio_url}")
        return audio_url

    async def run_parallel_processing(
        self,
        audio_path: Path,
        audio_url: str,
        source_language: str,
        target_language: str,
    ):
        """Coordinate parallel processing of transcription, diarization, and waveform"""
        self.logger.info(
            "Starting parallel processing", transcript_id=self.transcript_id
        )

        # Phase 1: Parallel processing of independent tasks
        transcription_task = self.transcribe_file(audio_url, source_language)
        diarization_task = self.diarize_file(audio_url)
        waveform_task = self.generate_waveform(audio_path)

        results = await asyncio.gather(
            transcription_task, diarization_task, waveform_task, return_exceptions=True
        )

        transcript_result = results[0]
        diarization_result = results[1]

        # Handle errors - raise any exception that occurred
        self._handle_gather_exceptions(results, "parallel processing")
        for result in results:
            if isinstance(result, Exception):
                raise result

        # Phase 2: Assemble transcript with diarization
        self.logger.info(
            "Assembling transcript with diarization", transcript_id=self.transcript_id
        )
        processor = TranscriptDiarizationAssemblerProcessor()
        input_data = TranscriptDiarizationAssemblerInput(
            transcript=transcript_result, diarization=diarization_result or []
        )

        # Store result for retrieval
        diarized_transcript: Transcript | None = None

        async def capture_result(transcript):
            nonlocal diarized_transcript
            diarized_transcript = transcript

        processor.on(capture_result)
        await processor.push(input_data)
        await processor.flush()

        if not diarized_transcript:
            raise ValueError("No diarized transcript captured")

        # Phase 3: Generate topics from diarized transcript
        self.logger.info("Generating topics", transcript_id=self.transcript_id)
        topics = await self.detect_topics(diarized_transcript, target_language)

        # Phase 4: Generate title and summaries in parallel
        self.logger.info(
            "Generating title and summaries", transcript_id=self.transcript_id
        )
        results = await asyncio.gather(
            self.generate_title(topics),
            self.generate_summaries(topics),
            return_exceptions=True,
        )

        self._handle_gather_exceptions(results, "title and summary generation")

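Both gather calls above rely on return_exceptions=True, which turns failures into values in the results list instead of letting the first failure cancel its siblings. A self-contained sketch of the pattern (toy coroutines invented for illustration):

import asyncio

async def ok():
    return "done"

async def boom():
    raise RuntimeError("transcription failed")

async def main():
    # Failures come back as exception objects, not raised mid-gather.
    results = await asyncio.gather(ok(), boom(), return_exceptions=True)
    for i, r in enumerate(results):
        if isinstance(r, Exception):
            print(f"task {i} failed: {r}")  # log, then decide whether to re-raise

asyncio.run(main())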
    async def transcribe_file(self, audio_url: str, language: str) -> TranscriptType:
        """Transcribe complete file"""
        processor = FileTranscriptAutoProcessor()
        input_data = FileTranscriptInput(audio_url=audio_url, language=language)

        # Store result for retrieval
        result: TranscriptType | None = None

        async def capture_result(transcript):
            nonlocal result
            result = transcript

        processor.on(capture_result)
        await processor.push(input_data)
        await processor.flush()

        if not result:
            raise ValueError("No transcript captured")

        return result

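The push/flush-then-capture idiom in transcribe_file recurs throughout this pipeline: processors deliver results to subscribers, so a nonlocal closure pulls the single emitted value back out. A stripped-down sketch (the Emitter stand-in below is invented; real processors inherit this behavior from the Processor base class):

import asyncio

class Emitter:
    def __init__(self):
        self._subscribers = []

    def on(self, callback):
        self._subscribers.append(callback)

    async def push(self, data):
        for cb in self._subscribers:
            await cb(data.upper())  # pretend "processing"

    async def flush(self):
        pass

async def main():
    result = None

    async def capture(value):
        nonlocal result
        result = value

    processor = Emitter()
    processor.on(capture)
    await processor.push("transcript")
    await processor.flush()
    assert result == "TRANSCRIPT"

asyncio.run(main())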
    async def diarize_file(self, audio_url: str) -> list[DiarizationSegment] | None:
        """Get diarization for file"""
        if not settings.DIARIZATION_BACKEND:
            self.logger.info("Diarization disabled")
            return None

        processor = FileDiarizationAutoProcessor()
        input_data = FileDiarizationInput(audio_url=audio_url)

        # Store result for retrieval
        result = None

        async def capture_result(diarization_output):
            nonlocal result
            result = diarization_output.diarization

        try:
            processor.on(capture_result)
            await processor.push(input_data)
            await processor.flush()
            return result
        except Exception as e:
            self.logger.error(f"Diarization failed: {e}")
            return None

    async def generate_waveform(self, audio_path: Path):
        """Generate and save waveform"""
        transcript = await self.get_transcript()

        processor = AudioWaveformProcessor(
            audio_path=audio_path,
            waveform_path=transcript.audio_waveform_filename,
            on_waveform=self.on_waveform,
        )
        processor.set_pipeline(self.empty_pipeline)

        await processor.flush()

    async def detect_topics(
        self, transcript: TranscriptType, target_language: str
    ) -> list[TitleSummary]:
        """Detect topics from complete transcript"""
        chunk_size = 300
        topics: list[TitleSummary] = []

        async def on_topic(topic: TitleSummary):
            topics.append(topic)
            return await self.on_topic(topic)

        topic_detector = TranscriptTopicDetectorProcessor(callback=on_topic)
        topic_detector.set_pipeline(self.empty_pipeline)

        for i in range(0, len(transcript.words), chunk_size):
            chunk_words = transcript.words[i : i + chunk_size]
            if not chunk_words:
                continue

            chunk_transcript = TranscriptType(
                words=chunk_words, translation=transcript.translation
            )

            await topic_detector.push(chunk_transcript)

        await topic_detector.flush()
        return topics

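detect_topics feeds the detector fixed-size word windows rather than the whole transcript at once; the slicing is plain range stepping. In isolation (toy word list, same 300-word chunk size):

words = [f"w{i}" for i in range(700)]
chunk_size = 300
chunks = [words[i : i + chunk_size] for i in range(0, len(words), chunk_size)]
print([len(c) for c in chunks])  # [300, 300, 100]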
    async def generate_title(self, topics: list[TitleSummary]):
        """Generate title from topics"""
        if not topics:
            self.logger.warning("No topics for title generation")
            return

        processor = TranscriptFinalTitleProcessor(callback=self.on_title)
        processor.set_pipeline(self.empty_pipeline)

        for topic in topics:
            await processor.push(topic)

        await processor.flush()

    async def generate_summaries(self, topics: list[TitleSummary]):
        """Generate long and short summaries from topics"""
        if not topics:
            self.logger.warning("No topics for summary generation")
            return

        transcript = await self.get_transcript()
        processor = TranscriptFinalSummaryProcessor(
            transcript=transcript,
            callback=self.on_long_summary,
            on_short_summary=self.on_short_summary,
        )
        processor.set_pipeline(self.empty_pipeline)

        for topic in topics:
            await processor.push(topic)

        await processor.flush()


@shared_task
@asynctask
async def task_pipeline_file_process(*, transcript_id: str):
    """Celery task for file pipeline processing"""

    transcript = await transcripts_controller.get_by_id(transcript_id)
    if not transcript:
        raise Exception(f"Transcript {transcript_id} not found")

    # Find the file to process
    audio_file = next(transcript.data_path.glob("upload.*"), None)
    if not audio_file:
        audio_file = next(transcript.data_path.glob("audio.*"), None)

    if not audio_file:
        raise Exception("No audio file found to process")

    # Run file pipeline
    pipeline = PipelineMainFile(transcript_id=transcript_id)
    await pipeline.process(audio_file)
@@ -147,18 +147,15 @@ class StrValue(BaseModel):


class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]):
    def __init__(self, transcript_id: str):
        super().__init__()
        self._lock = asyncio.Lock()
        self.transcript_id = transcript_id
        self.ws_room_id = f"ts:{self.transcript_id}"
        self._ws_manager = None
    transcript_id: str
    ws_room_id: str | None = None
    ws_manager: WebsocketManager | None = None

    @property
    def ws_manager(self) -> WebsocketManager:
        if self._ws_manager is None:
            self._ws_manager = get_ws_manager()
        return self._ws_manager
    def prepare(self):
        # prepare websocket
        self._lock = asyncio.Lock()
        self.ws_room_id = f"ts:{self.transcript_id}"
        self.ws_manager = get_ws_manager()

    async def get_transcript(self) -> Transcript:
        # fetch the transcript
@@ -358,6 +355,7 @@ class PipelineMainLive(PipelineMainBase):
    async def create(self) -> Pipeline:
        # create a context for the whole rtc transaction
        # add a customised logger to the context
        self.prepare()
        transcript = await self.get_transcript()

        processors = [
@@ -378,7 +376,6 @@ class PipelineMainLive(PipelineMainBase):
        pipeline.set_pref("audio:target_language", transcript.target_language)
        pipeline.logger.bind(transcript_id=transcript.id)
        pipeline.logger.info("Pipeline main live created")
        pipeline.describe()

        return pipeline

@@ -397,6 +394,7 @@ class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
    async def create(self) -> Pipeline:
        # create a context for the whole rtc transaction
        # add a customised logger to the context
        self.prepare()
        pipeline = Pipeline(
            AudioDiarizationAutoProcessor(callback=self.on_topic),
        )
@@ -437,6 +435,8 @@ class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
        raise NotImplementedError

    async def create(self) -> Pipeline:
        self.prepare()

        # get transcript
        self._transcript = transcript = await self.get_transcript()


@@ -18,14 +18,22 @@ During its lifecycle, it will emit the following status:
import asyncio
from typing import Generic, TypeVar

from pydantic import BaseModel, ConfigDict

from reflector.logger import logger
from reflector.processors import Pipeline

PipelineMessage = TypeVar("PipelineMessage")


class PipelineRunner(Generic[PipelineMessage]):
    def __init__(self):
class PipelineRunner(BaseModel, Generic[PipelineMessage]):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    status: str = "idle"
    pipeline: Pipeline | None = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._task = None
        self._q_cmd = asyncio.Queue(maxsize=4096)
        self._ev_done = asyncio.Event()
@@ -34,8 +42,6 @@ class PipelineRunner(Generic[PipelineMessage]):
            runner=id(self),
            runner_cls=self.__class__.__name__,
        )
        self.status = "idle"
        self.pipeline: Pipeline | None = None

    async def create(self) -> Pipeline:
        """

@@ -11,13 +11,6 @@ from .base import ( # noqa: F401
    Processor,
    ThreadedProcessor,
)
from .file_diarization import FileDiarizationProcessor  # noqa: F401
from .file_diarization_auto import FileDiarizationAutoProcessor  # noqa: F401
from .file_transcript import FileTranscriptProcessor  # noqa: F401
from .file_transcript_auto import FileTranscriptAutoProcessor  # noqa: F401
from .transcript_diarization_assembler import (
    TranscriptDiarizationAssemblerProcessor,  # noqa: F401
)
from .transcript_final_summary import TranscriptFinalSummaryProcessor  # noqa: F401
from .transcript_final_title import TranscriptFinalTitleProcessor  # noqa: F401
from .transcript_liner import TranscriptLinerProcessor  # noqa: F401

@@ -1,340 +1,28 @@
from typing import Optional

import av
import numpy as np
import torch
from silero_vad import VADIterator, load_silero_vad

from reflector.processors.base import Processor


class AudioChunkerProcessor(Processor):
    """
    Assemble audio frames into chunks with VAD-based speech detection
    Assemble audio frames into chunks
    """

    INPUT_TYPE = av.AudioFrame
    OUTPUT_TYPE = list[av.AudioFrame]

    def __init__(
        self,
        block_frames=256,
        max_frames=1024,
        vad_threshold=0.5,
        use_onnx=False,
        min_frames=2,
    ):
    def __init__(self, max_frames=256):
        super().__init__()
        self.frames: list[av.AudioFrame] = []
        self.block_frames = block_frames
        self.max_frames = max_frames
        self.vad_threshold = vad_threshold
        self.min_frames = min_frames

        # Initialize Silero VAD
        self._init_vad(use_onnx)

    def _init_vad(self, use_onnx=False):
        """Initialize Silero VAD model"""
        try:
            torch.set_num_threads(1)
            self.vad_model = load_silero_vad(onnx=use_onnx)
            self.vad_iterator = VADIterator(self.vad_model, sampling_rate=16000)
            self.logger.info("Silero VAD initialized successfully")

        except Exception as e:
            self.logger.error(f"Failed to initialize Silero VAD: {e}")
            self.vad_model = None
            self.vad_iterator = None

    async def _push(self, data: av.AudioFrame):
        self.frames.append(data)
        # print("timestamp", data.pts * data.time_base * 1000)

        # Check for speech segments every 32 frames (~1 second)
        if len(self.frames) >= 32 and len(self.frames) % 32 == 0:
            await self._process_block()

        # Safety fallback - emit if we hit max frames
        elif len(self.frames) >= self.max_frames:
            self.logger.warning(
                f"AudioChunkerProcessor: Reached max frames ({self.max_frames}), "
                f"emitting first {self.max_frames // 2} frames"
            )
            frames_to_emit = self.frames[: self.max_frames // 2]
            self.frames = self.frames[self.max_frames // 2 :]
            if len(frames_to_emit) >= self.min_frames:
                await self.emit(frames_to_emit)
            else:
                self.logger.debug(
                    f"Ignoring fallback segment with {len(frames_to_emit)} frames "
                    f"(< {self.min_frames} minimum)"
                )

    async def _process_block(self):
        # Need at least 32 frames for VAD detection (~1 second)
        if len(self.frames) < 32 or self.vad_iterator is None:
            return

        # Processing block with current buffer size
        # print(f"Processing block: {len(self.frames)} frames in buffer")

        try:
            # Convert frames to numpy array for VAD
            audio_array = self._frames_to_numpy(self.frames)

            if audio_array is None:
                # Fallback: emit all frames if conversion failed
                frames_to_emit = self.frames[:]
                self.frames = []
                if len(frames_to_emit) >= self.min_frames:
                    await self.emit(frames_to_emit)
                else:
                    self.logger.debug(
                        f"Ignoring conversion-failed segment with {len(frames_to_emit)} frames "
                        f"(< {self.min_frames} minimum)"
                    )
                return

            # Find complete speech segments in the buffer
            speech_end_frame = self._find_speech_segment_end(audio_array)

            if speech_end_frame is None or speech_end_frame <= 0:
                # No speech found but buffer is getting large
                if len(self.frames) > 512:
                    # Check if it's all silence and can be discarded
                    # No speech segment found, buffer at {len(self.frames)} frames

                    # Could emit silence or discard old frames here
                    # For now, keep first 256 frames and discard older silence
                    if len(self.frames) > 768:
                        self.logger.debug(
                            f"Discarding {len(self.frames) - 256} old frames (likely silence)"
                        )
                        self.frames = self.frames[-256:]
                return

            # Calculate segment timing information
            frames_to_emit = self.frames[:speech_end_frame]

            # Get timing from av.AudioFrame
            if frames_to_emit:
                first_frame = frames_to_emit[0]
                last_frame = frames_to_emit[-1]
                sample_rate = first_frame.sample_rate

                # Calculate duration
                total_samples = sum(f.samples for f in frames_to_emit)
                duration_seconds = total_samples / sample_rate if sample_rate > 0 else 0

                # Get timestamps if available
                start_time = (
                    first_frame.pts * first_frame.time_base if first_frame.pts else 0
                )
                end_time = (
                    last_frame.pts * last_frame.time_base if last_frame.pts else 0
                )

                # Convert to HH:MM:SS format for logging
                def format_time(seconds):
                    if not seconds:
                        return "00:00:00"
                    total_seconds = int(float(seconds))
                    hours = total_seconds // 3600
                    minutes = (total_seconds % 3600) // 60
                    secs = total_seconds % 60
                    return f"{hours:02d}:{minutes:02d}:{secs:02d}"

                start_formatted = format_time(start_time)
                end_formatted = format_time(end_time)

                # Keep remaining frames for next processing
                remaining_after = len(self.frames) - speech_end_frame

                # Single structured log line
                self.logger.info(
                    "Speech segment found",
                    start=start_formatted,
                    end=end_formatted,
                    frames=speech_end_frame,
                    duration=round(duration_seconds, 2),
                    buffer_before=len(self.frames),
                    remaining=remaining_after,
                )

            # Keep remaining frames for next processing
            self.frames = self.frames[speech_end_frame:]

            # Filter out segments with too few frames
            if len(frames_to_emit) >= self.min_frames:
                await self.emit(frames_to_emit)
            else:
                self.logger.debug(
                    f"Ignoring segment with {len(frames_to_emit)} frames "
                    f"(< {self.min_frames} minimum)"
                )

        except Exception as e:
            self.logger.error(f"Error in VAD processing: {e}")
            # Fallback to simple chunking
            if len(self.frames) >= self.block_frames:
                frames_to_emit = self.frames[: self.block_frames]
                self.frames = self.frames[self.block_frames :]
                if len(frames_to_emit) >= self.min_frames:
                    await self.emit(frames_to_emit)
                else:
                    self.logger.debug(
                        f"Ignoring exception-fallback segment with {len(frames_to_emit)} frames "
                        f"(< {self.min_frames} minimum)"
                    )

    def _frames_to_numpy(self, frames: list[av.AudioFrame]) -> Optional[np.ndarray]:
        """Convert av.AudioFrame list to numpy array for VAD processing"""
        if not frames:
            return None

        try:
            first_frame = frames[0]
            original_sample_rate = first_frame.sample_rate

            audio_data = []
            for frame in frames:
                frame_array = frame.to_ndarray()

                # Handle stereo -> mono conversion
                if len(frame_array.shape) == 2 and frame_array.shape[0] > 1:
                    frame_array = np.mean(frame_array, axis=0)
                elif len(frame_array.shape) == 2:
                    frame_array = frame_array.flatten()

                audio_data.append(frame_array)

            if not audio_data:
                return None

            combined_audio = np.concatenate(audio_data)

            # Resample from 48kHz to 16kHz if needed
            if original_sample_rate != 16000:
                combined_audio = self._resample_audio(
                    combined_audio, original_sample_rate, 16000
                )

            # Ensure float32 format
            if combined_audio.dtype == np.int16:
                # Normalize int16 audio to float32 in range [-1.0, 1.0]
                combined_audio = combined_audio.astype(np.float32) / 32768.0
            elif combined_audio.dtype != np.float32:
                combined_audio = combined_audio.astype(np.float32)

            return combined_audio

        except Exception as e:
            self.logger.error(f"Error converting frames to numpy: {e}")

            return None

    def _resample_audio(
        self, audio: np.ndarray, from_sr: int, to_sr: int
    ) -> np.ndarray:
        """Simple linear resampling from from_sr to to_sr"""
        if from_sr == to_sr:
            return audio

        try:
            # Simple linear interpolation resampling
            ratio = to_sr / from_sr
            new_length = int(len(audio) * ratio)

            # Create indices for interpolation
            old_indices = np.linspace(0, len(audio) - 1, new_length)
            resampled = np.interp(old_indices, np.arange(len(audio)), audio)

            return resampled.astype(np.float32)

        except Exception as e:
            self.logger.error("Resampling error", exc_info=e)
            # Fallback: simple decimation/repetition
            if from_sr > to_sr:
                # Downsample by taking every nth sample
                step = from_sr // to_sr
                return audio[::step]
            else:
                # Upsample by repeating samples
                repeat = to_sr // from_sr
                return np.repeat(audio, repeat)

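The linear resampler above builds target sample positions in the source index space with np.linspace, then interpolates with np.interp. A tiny standalone check (invented four-sample signal; same call shape as the method):

import numpy as np

audio = np.array([0.0, 1.0, 0.0, -1.0], dtype=np.float32)
from_sr, to_sr = 4, 8                                   # upsample 2x
new_length = int(len(audio) * to_sr / from_sr)          # 8 samples
old_indices = np.linspace(0, len(audio) - 1, new_length)
resampled = np.interp(old_indices, np.arange(len(audio)), audio)
print(resampled.round(2))  # [ 0.    0.43  0.86  0.71  0.29 -0.14 -0.57 -1.  ]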
    def _find_speech_segment_end(self, audio_array: np.ndarray) -> Optional[int]:
        """Find complete speech segments and return frame index at segment end"""
        if self.vad_iterator is None or len(audio_array) == 0:
            return None

        try:
            # Process audio in 512-sample windows for VAD
            window_size = 512
            min_silence_windows = 3  # Require 3 windows of silence after speech

            # Track speech state
            in_speech = False
            speech_start = None
            speech_end = None
            silence_count = 0

            for i in range(0, len(audio_array), window_size):
                chunk = audio_array[i : i + window_size]
                if len(chunk) < window_size:
                    chunk = np.pad(chunk, (0, window_size - len(chunk)))

                # Detect if this window has speech
                speech_dict = self.vad_iterator(chunk, return_seconds=True)

                # VADIterator returns dict with 'start' and 'end' when speech segments are detected
                if speech_dict:
                    if not in_speech:
                        # Speech started
                        speech_start = i
                        in_speech = True
                        # Debug: print(f"Speech START at sample {i}, VAD: {speech_dict}")
                    silence_count = 0  # Reset silence counter
                    continue

                if not in_speech:
                    continue

                # We're in speech but found silence
                silence_count += 1
                if silence_count < min_silence_windows:
                    continue

                # Found end of speech segment
                speech_end = i - (min_silence_windows - 1) * window_size
                # Debug: print(f"Speech END at sample {speech_end}")

                # Convert sample position to frame index
                samples_per_frame = self.frames[0].samples if self.frames else 1024
                # Account for resampling: we process at 16kHz but frames might be 48kHz
                resample_ratio = 48000 / 16000  # 3x
                actual_sample_pos = int(speech_end * resample_ratio)
                frame_index = actual_sample_pos // samples_per_frame

                # Ensure we don't exceed buffer
                frame_index = min(frame_index, len(self.frames))
                return frame_index

            return None

        except Exception as e:
            self.logger.error(f"Error finding speech segment: {e}")
            return None
        if len(self.frames) >= self.max_frames:
            await self.flush()

    async def _flush(self):
        frames = self.frames[:]
        self.frames = []
        if frames:
            if len(frames) >= self.min_frames:
                await self.emit(frames)
            else:
                self.logger.debug(
                    f"Ignoring flush segment with {len(frames)} frames "
                    f"(< {self.min_frames} minimum)"
                )
            await self.emit(frames)

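The sample-to-frame conversion in _find_speech_segment_end undoes the 16 kHz analysis downsampling before indexing the 48 kHz frame buffer. The arithmetic in isolation (toy numbers; the 3x ratio and the 1024-sample fallback frame size are the same assumptions the code makes):

speech_end = 8000                                       # end sample in the 16 kHz signal
resample_ratio = 48000 / 16000                          # one analysis sample = 3 buffered samples
samples_per_frame = 1024                                # fallback frame size used above
actual_sample_pos = int(speech_end * resample_ratio)    # 24000 samples at 48 kHz
frame_index = actual_sample_pos // samples_per_frame    # 23 complete frames
print(frame_index)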
@@ -1,7 +1,6 @@
from reflector.processors.base import Processor
from reflector.processors.types import (
    AudioDiarizationInput,
    DiarizationSegment,
    TitleSummary,
    Word,
)
@@ -38,21 +37,18 @@ class AudioDiarizationProcessor(Processor):
    async def _diarize(self, data: AudioDiarizationInput):
        raise NotImplementedError

    @classmethod
    def assign_speaker(cls, words: list[Word], diarization: list[DiarizationSegment]):
        cls._diarization_remove_overlap(diarization)
        cls._diarization_remove_segment_without_words(words, diarization)
        cls._diarization_merge_same_speaker(diarization)
        cls._diarization_assign_speaker(words, diarization)
    def assign_speaker(self, words: list[Word], diarization: list[dict]):
        self._diarization_remove_overlap(diarization)
        self._diarization_remove_segment_without_words(words, diarization)
        self._diarization_merge_same_speaker(words, diarization)
        self._diarization_assign_speaker(words, diarization)

    @staticmethod
    def iter_words_from_topics(topics: list[TitleSummary]):
    def iter_words_from_topics(self, topics: TitleSummary):
        for topic in topics:
            for word in topic.transcript.words:
                yield word

    @staticmethod
    def is_word_continuation(word_prev, word):
    def is_word_continuation(self, word_prev, word):
        """
        Return True if the word is a continuation of the previous word
        by checking if the previous word ends with a punctuation mark
@@ -65,8 +61,7 @@ class AudioDiarizationProcessor(Processor):
            return False
        return True

    @staticmethod
    def _diarization_remove_overlap(diarization: list[DiarizationSegment]):
    def _diarization_remove_overlap(self, diarization: list[dict]):
        """
        Remove overlap in diarization results

@@ -91,9 +86,8 @@ class AudioDiarizationProcessor(Processor):
            else:
                diarization_idx += 1

    @staticmethod
    def _diarization_remove_segment_without_words(
        words: list[Word], diarization: list[DiarizationSegment]
        self, words: list[Word], diarization: list[dict]
    ):
        """
        Remove diarization segments without words
@@ -122,8 +116,9 @@ class AudioDiarizationProcessor(Processor):
            else:
                diarization_idx += 1

    @staticmethod
    def _diarization_merge_same_speaker(diarization: list[DiarizationSegment]):
    def _diarization_merge_same_speaker(
        self, words: list[Word], diarization: list[dict]
    ):
        """
        Merge contiguous diarization segments with the same speaker

@@ -140,10 +135,7 @@ class AudioDiarizationProcessor(Processor):
            else:
                diarization_idx += 1

    @classmethod
    def _diarization_assign_speaker(
        cls, words: list[Word], diarization: list[DiarizationSegment]
    ):
    def _diarization_assign_speaker(self, words: list[Word], diarization: list[dict]):
        """
        Assign speaker to words based on diarization

@@ -151,7 +143,7 @@ class AudioDiarizationProcessor(Processor):
        """

        word_idx = 0
        last_speaker = 0
        last_speaker = None
        for d in diarization:
            start = d["start"]
            end = d["end"]
@@ -166,7 +158,7 @@ class AudioDiarizationProcessor(Processor):
            # If it's a continuation, assign with the last speaker
            is_continuation = False
            if word_idx > 0 and word_idx < len(words) - 1:
                is_continuation = cls.is_word_continuation(
                is_continuation = self.is_word_continuation(
                    *words[word_idx - 1 : word_idx + 1]
                )
                if is_continuation:

@@ -1,74 +0,0 @@
import os

import torch
import torchaudio
from pyannote.audio import Pipeline

from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
from reflector.processors.types import AudioDiarizationInput, DiarizationSegment


class AudioDiarizationPyannoteProcessor(AudioDiarizationProcessor):
    """Local diarization processor using pyannote.audio library"""

    def __init__(
        self,
        model_name: str = "pyannote/speaker-diarization-3.1",
        pyannote_auth_token: str | None = None,
        device: str | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.auth_token = pyannote_auth_token or os.environ.get("HF_TOKEN")
        self.device = device

        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.logger.info(f"Loading pyannote diarization model: {self.model_name}")
        self.diarization_pipeline = Pipeline.from_pretrained(
            self.model_name, use_auth_token=self.auth_token
        )
        self.diarization_pipeline.to(torch.device(self.device))
        self.logger.info(f"Diarization model loaded on device: {self.device}")

    async def _diarize(self, data: AudioDiarizationInput) -> list[DiarizationSegment]:
        try:
            # Load audio file (audio_url is assumed to be a local file path)
            self.logger.info(f"Loading local audio file: {data.audio_url}")
            waveform, sample_rate = torchaudio.load(data.audio_url)
            audio_input = {"waveform": waveform, "sample_rate": sample_rate}
            self.logger.info("Running speaker diarization")
            diarization = self.diarization_pipeline(audio_input)

            # Convert pyannote diarization output to our format
            segments = []
            for segment, _, speaker in diarization.itertracks(yield_label=True):
                # Extract speaker number from label (e.g., "SPEAKER_00" -> 0)
                speaker_id = 0
                if speaker.startswith("SPEAKER_"):
                    try:
                        speaker_id = int(speaker.split("_")[-1])
                    except (ValueError, IndexError):
                        # Fallback to hash-based ID if parsing fails
                        speaker_id = hash(speaker) % 1000

                segments.append(
                    {
                        "start": round(segment.start, 3),
                        "end": round(segment.end, 3),
                        "speaker": speaker_id,
                    }
                )

            self.logger.info(f"Diarization completed with {len(segments)} segments")
            return segments

        except Exception as e:
            self.logger.exception(f"Diarization failed: {e}")
            raise


AudioDiarizationAutoProcessor.register("pyannote", AudioDiarizationPyannoteProcessor)
@@ -3,24 +3,11 @@ from time import monotonic_ns
from uuid import uuid4

import av
from av.audio.resampler import AudioResampler

from reflector.processors.base import Processor
from reflector.processors.types import AudioFile


def copy_frame(frame: av.AudioFrame) -> av.AudioFrame:
    frame_copy = frame.from_ndarray(
        frame.to_ndarray(),
        format=frame.format.name,
        layout=frame.layout.name,
    )
    frame_copy.sample_rate = frame.sample_rate
    frame_copy.pts = frame.pts
    frame_copy.time_base = frame.time_base
    return frame_copy


class AudioMergeProcessor(Processor):
    """
    Merge audio frames into a single file
@@ -29,92 +16,37 @@ class AudioMergeProcessor(Processor):
    INPUT_TYPE = list[av.AudioFrame]
    OUTPUT_TYPE = AudioFile

    def __init__(self, downsample_to_16k_mono: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.downsample_to_16k_mono = downsample_to_16k_mono

    async def _push(self, data: list[av.AudioFrame]):
        if not data:
            return

        # get audio information from first frame
        frame = data[0]
        original_channels = len(frame.layout.channels)
        original_sample_rate = frame.sample_rate
        original_sample_width = frame.format.bytes

        # determine if we need processing
        needs_processing = self.downsample_to_16k_mono and (
            original_sample_rate != 16000 or original_channels != 1
        )

        # determine output parameters
        if self.downsample_to_16k_mono:
            output_sample_rate = 16000
            output_channels = 1
            output_sample_width = 2  # 16-bit = 2 bytes
        else:
            output_sample_rate = original_sample_rate
            output_channels = original_channels
            output_sample_width = original_sample_width
        channels = len(frame.layout.channels)
        sample_rate = frame.sample_rate
        sample_width = frame.format.bytes

        # create audio file
        uu = uuid4().hex
        fd = io.BytesIO()

        if needs_processing:
            # Process with PyAV resampler
            out_container = av.open(fd, "w", format="wav")
            out_stream = out_container.add_stream("pcm_s16le", rate=16000)
            out_stream.layout = "mono"

            # Create resampler if needed
            resampler = None
            if original_sample_rate != 16000 or original_channels != 1:
                resampler = AudioResampler(format="s16", layout="mono", rate=16000)

            for frame in data:
                if resampler:
                    # Resample and convert to mono
                    # XXX for an unknown reason, if we don't use a copy of the frame, we get
                    # "Invalid Argument" from resample. Debugging indicates that when a previous
                    # processor has already used the frame (like AudioFileWriter), it becomes an
                    # invalid argument here.
                    resampled_frames = resampler.resample(copy_frame(frame))
                    for resampled_frame in resampled_frames:
                        for packet in out_stream.encode(resampled_frame):
                            out_container.mux(packet)
                else:
                    # Direct encoding without resampling
                    for packet in out_stream.encode(frame):
                        out_container.mux(packet)

            # Flush the encoder
            for packet in out_stream.encode(None):
        out_container = av.open(fd, "w", format="wav")
        out_stream = out_container.add_stream("pcm_s16le", rate=sample_rate)
        for frame in data:
            for packet in out_stream.encode(frame):
                out_container.mux(packet)
        out_container.close()
        else:
            # Use PyAV for original frames (no processing needed)
            out_container = av.open(fd, "w", format="wav")
            out_stream = out_container.add_stream("pcm_s16le", rate=output_sample_rate)
            out_stream.layout = "mono" if output_channels == 1 else frame.layout

            for frame in data:
                for packet in out_stream.encode(frame):
                    out_container.mux(packet)

            for packet in out_stream.encode(None):
                out_container.mux(packet)
            out_container.close()

        for packet in out_stream.encode(None):
            out_container.mux(packet)
        out_container.close()
        fd.seek(0)

        # emit audio file
        audiofile = AudioFile(
            name=f"{monotonic_ns()}-{uu}.wav",
            fd=fd,
            sample_rate=output_sample_rate,
            channels=output_channels,
            sample_width=output_sample_width,
            sample_rate=sample_rate,
            channels=channels,
            sample_width=sample_width,
            timestamp=data[0].pts * data[0].time_base,
        )


@@ -12,9 +12,6 @@ API will be a POST request to TRANSCRIPT_URL:

"""

from typing import List

import aiohttp
from openai import AsyncOpenAI

from reflector.processors.audio_transcript import AudioTranscriptProcessor
@@ -24,9 +21,7 @@ from reflector.settings import settings


class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
    def __init__(
        self, modal_api_key: str | None = None, batch_enabled: bool = True, **kwargs
    ):
    def __init__(self, modal_api_key: str | None = None, **kwargs):
        super().__init__()
        if not settings.TRANSCRIPT_URL:
            raise Exception(
@@ -35,126 +30,6 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
        self.transcript_url = settings.TRANSCRIPT_URL + "/v1"
        self.timeout = settings.TRANSCRIPT_TIMEOUT
        self.modal_api_key = modal_api_key
        self.max_batch_duration = 10.0
        self.max_batch_files = 15
        self.batch_enabled = batch_enabled
        self.pending_files: List[AudioFile] = []  # Files waiting to be processed

    @classmethod
    def _calculate_duration(cls, audio_file: AudioFile) -> float:
        """Calculate audio duration in seconds from AudioFile metadata"""
        # Duration = total_samples / sample_rate
        # We need to estimate total samples from the file data
        import wave

        try:
            # Try to read as WAV file to get duration
            audio_file.fd.seek(0)
            with wave.open(audio_file.fd, "rb") as wav_file:
                frames = wav_file.getnframes()
                sample_rate = wav_file.getframerate()
                duration = frames / sample_rate
                return duration
        except Exception:
            # Fallback: estimate from file size and audio parameters
            audio_file.fd.seek(0, 2)  # Seek to end
            file_size = audio_file.fd.tell()
            audio_file.fd.seek(0)  # Reset to beginning

            # Estimate: file_size / (sample_rate * channels * sample_width)
            bytes_per_second = (
                audio_file.sample_rate
                * audio_file.channels
                * (audio_file.sample_width // 8)
            )
            estimated_duration = (
                file_size / bytes_per_second if bytes_per_second > 0 else 0
            )
            return max(0, estimated_duration)

    def _create_batches(self, audio_files: List[AudioFile]) -> List[List[AudioFile]]:
        """Group audio files into batches with maximum 30s total duration"""
        batches = []
        current_batch = []
        current_duration = 0.0

        for audio_file in audio_files:
            duration = self._calculate_duration(audio_file)

            # If adding this file exceeds max duration, start a new batch
            if current_duration + duration > self.max_batch_duration and current_batch:
                batches.append(current_batch)
                current_batch = [audio_file]
                current_duration = duration
            else:
                current_batch.append(audio_file)
                current_duration += duration

        # Add the last batch if not empty
        if current_batch:
            batches.append(current_batch)

        return batches

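_create_batches is a greedy accumulator: keep appending files until the running duration would cross the cap, then start a new batch. The same loop on bare numbers (durations invented; the 10.0 cap matches max_batch_duration from __init__, despite the 30 s mentioned in the docstring):

durations = [4.0, 5.0, 3.0, 9.0]
max_batch_duration = 10.0
batches, current, total = [], [], 0.0
for d in durations:
    if total + d > max_batch_duration and current:
        batches.append(current)
        current, total = [d], d
    else:
        current.append(d)
        total += d
if current:
    batches.append(current)
print(batches)  # [[4.0, 5.0], [3.0], [9.0]]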
async def _transcript_batch(self, audio_files: List[AudioFile]) -> List[Transcript]:
|
||||
"""Transcribe a batch of audio files using the parakeet backend"""
|
||||
if not audio_files:
|
||||
return []
|
||||
|
||||
self.logger.debug(f"Batch transcribing {len(audio_files)} files")
|
||||
|
||||
# Prepare form data for batch request
|
||||
data = aiohttp.FormData()
|
||||
data.add_field("language", self.get_pref("audio:source_language", "en"))
|
||||
data.add_field("batch", "true")
|
||||
|
||||
for i, audio_file in enumerate(audio_files):
|
||||
audio_file.fd.seek(0)
|
||||
data.add_field(
|
||||
"files",
|
||||
audio_file.fd,
|
||||
filename=f"{audio_file.name}",
|
||||
content_type="audio/wav",
|
||||
)
|
||||
|
||||
# Make batch request
|
||||
headers = {"Authorization": f"Bearer {self.modal_api_key}"}
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=self.timeout)
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"{self.transcript_url}/audio/transcriptions",
|
||||
data=data,
|
||||
headers=headers,
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise Exception(
|
||||
f"Batch transcription failed: {response.status} {error_text}"
|
||||
)
|
||||
|
||||
result = await response.json()
|
||||
|
||||
# Process batch results
|
||||
transcripts = []
|
||||
results = result.get("results", [])
|
||||
|
||||
for i, (audio_file, file_result) in enumerate(zip(audio_files, results)):
|
||||
transcript = Transcript(
|
||||
words=[
|
||||
Word(
|
||||
text=word_info["word"],
|
||||
start=word_info["start"],
|
||||
end=word_info["end"],
|
||||
)
|
||||
for word_info in file_result.get("words", [])
|
||||
]
|
||||
)
|
||||
transcript.add_offset(audio_file.timestamp)
|
||||
transcripts.append(transcript)
|
||||
|
||||
return transcripts
|
||||
|
||||
async def _transcript(self, data: AudioFile):
|
||||
async with AsyncOpenAI(
|
||||
@@ -187,96 +62,5 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
||||
|
||||
return transcript
|
||||
|
||||
async def transcript_multiple(
|
||||
self, audio_files: List[AudioFile]
|
||||
) -> List[Transcript]:
|
||||
"""Transcribe multiple audio files using batching"""
|
||||
if len(audio_files) == 1:
|
||||
# Single file, use existing method
|
||||
return [await self._transcript(audio_files[0])]
|
||||
|
||||
# Create batches with max 30s duration each
|
||||
batches = self._create_batches(audio_files)
|
||||
|
||||
self.logger.debug(
|
||||
f"Processing {len(audio_files)} files in {len(batches)} batches"
|
||||
)
|
||||
|
||||
# Process all batches concurrently
|
||||
all_transcripts = []
|
||||
|
||||
for batch in batches:
|
||||
batch_transcripts = await self._transcript_batch(batch)
|
||||
all_transcripts.extend(batch_transcripts)
|
||||
|
||||
return all_transcripts
|
||||
|
||||
async def _push(self, data: AudioFile):
|
||||
"""Override _push to support batching"""
|
||||
if not self.batch_enabled:
|
||||
# Use parent implementation for single file processing
|
||||
return await super()._push(data)
|
||||
|
||||
# Add file to pending batch
|
||||
self.pending_files.append(data)
|
||||
self.logger.debug(
|
||||
f"Added file to batch: {data.name}, batch size: {len(self.pending_files)}"
|
||||
)
|
||||
|
||||
# Calculate total duration of pending files
|
||||
total_duration = sum(self._calculate_duration(f) for f in self.pending_files)
|
||||
|
||||
# Process batch if it reaches max duration or has multiple files ready for optimization
|
||||
should_process_batch = (
|
||||
total_duration >= self.max_batch_duration
|
||||
or len(self.pending_files) >= self.max_batch_files
|
||||
)
|
||||
|
||||
if should_process_batch:
|
||||
await self._process_pending_batch()
|
||||
|
||||
async def _process_pending_batch(self):
|
||||
"""Process all pending files as batches"""
|
||||
if not self.pending_files:
|
||||
return
|
||||
|
||||
self.logger.debug(f"Processing batch of {len(self.pending_files)} files")
|
||||
|
||||
try:
|
||||
# Create batches respecting duration limit
|
||||
batches = self._create_batches(self.pending_files)
|
||||
|
||||
# Process each batch
|
||||
for batch in batches:
|
||||
self.m_transcript_call.inc()
|
||||
try:
|
||||
with self.m_transcript.time():
|
||||
# Use batch transcription
|
||||
transcripts = await self._transcript_batch(batch)
|
||||
|
||||
self.m_transcript_success.inc()
|
||||
|
||||
# Emit each transcript
|
||||
for transcript in transcripts:
|
||||
if transcript:
|
||||
await self.emit(transcript)
|
||||
|
||||
except Exception:
|
||||
self.m_transcript_failure.inc()
|
||||
raise
|
||||
finally:
|
||||
# Release audio files
|
||||
for audio_file in batch:
|
||||
audio_file.release()
|
||||
|
||||
finally:
|
||||
# Clear pending files
|
||||
self.pending_files.clear()
|
||||
|
||||
async def _flush(self):
|
||||
"""Process any remaining files when flushing"""
|
||||
await self._process_pending_batch()
|
||||
await super()._flush()
|
||||
|
||||
|
||||
AudioTranscriptAutoProcessor.register("modal", AudioTranscriptModalProcessor)
|
||||
|
||||
@@ -173,7 +173,6 @@ class Processor(Emitter):
|
||||
except Exception:
|
||||
self.m_processor_failure.inc()
|
||||
self.logger.exception("Error in push")
|
||||
raise
|
||||
|
||||
async def flush(self):
|
||||
"""
|
||||
@@ -241,45 +240,33 @@ class ThreadedProcessor(Processor):
|
||||
self.INPUT_TYPE = processor.INPUT_TYPE
|
||||
self.OUTPUT_TYPE = processor.OUTPUT_TYPE
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self.queue = asyncio.Queue(maxsize=50)
|
||||
self.task: asyncio.Task | None = None
|
||||
self.queue = asyncio.Queue()
|
||||
self.task = asyncio.get_running_loop().create_task(self.loop())
|
||||
|
||||
def set_pipeline(self, pipeline: "Pipeline"):
|
||||
super().set_pipeline(pipeline)
|
||||
self.processor.set_pipeline(pipeline)
|
||||
|
||||
async def loop(self):
|
||||
try:
|
||||
while True:
|
||||
data = await self.queue.get()
|
||||
self.m_processor_queue.set(self.queue.qsize())
|
||||
with self.m_processor_queue_in_progress.track_inprogress():
|
||||
while True:
|
||||
data = await self.queue.get()
|
||||
self.m_processor_queue.set(self.queue.qsize())
|
||||
with self.m_processor_queue_in_progress.track_inprogress():
|
||||
try:
|
||||
if data is None:
|
||||
await self.processor.flush()
|
||||
break
|
||||
try:
|
||||
if data is None:
|
||||
await self.processor.flush()
|
||||
break
|
||||
try:
|
||||
await self.processor.push(data)
|
||||
except Exception:
|
||||
self.logger.error(
|
||||
f"Error in push {self.processor.__class__.__name__}"
|
||||
", continue"
|
||||
)
|
||||
finally:
|
||||
self.queue.task_done()
|
||||
except Exception as e:
|
||||
logger.error(f"Crash in {self.__class__.__name__}: {e}", exc_info=e)
|
||||
|
||||
async def _ensure_task(self):
|
||||
if self.task is None:
|
||||
self.task = asyncio.get_running_loop().create_task(self.loop())
|
||||
|
||||
# XXX not doing a sleep here make the whole pipeline prior the thread
|
||||
# to be running without having a chance to work on the task here.
|
||||
await asyncio.sleep(0)
|
||||
await self.processor.push(data)
|
||||
except Exception:
|
||||
self.logger.error(
|
||||
f"Error in push {self.processor.__class__.__name__}"
|
||||
", continue"
|
||||
)
|
||||
finally:
|
||||
self.queue.task_done()
|
||||
|
||||
async def _push(self, data):
|
||||
await self._ensure_task()
|
||||
await self.queue.put(data)
|
||||
|
||||
async def _flush(self):
|
||||
|
||||
@@ -1,33 +0,0 @@
from pydantic import BaseModel

from reflector.processors.base import Processor
from reflector.processors.types import DiarizationSegment


class FileDiarizationInput(BaseModel):
    """Input for file diarization containing audio URL"""

    audio_url: str


class FileDiarizationOutput(BaseModel):
    """Output for file diarization containing speaker segments"""

    diarization: list[DiarizationSegment]


class FileDiarizationProcessor(Processor):
    """
    Diarize complete audio files from URL
    """

    INPUT_TYPE = FileDiarizationInput
    OUTPUT_TYPE = FileDiarizationOutput

    async def _push(self, data: FileDiarizationInput):
        result = await self._diarize(data)
        if result:
            await self.emit(result)

    async def _diarize(self, data: FileDiarizationInput):
        raise NotImplementedError
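The base class boils down to a small contract: `push()` routes into `_push()`, which delegates to `_diarize()` and emits the result. A minimal sketch of a custom backend under that contract (the `echo` backend here is hypothetical):

```python
from reflector.processors.file_diarization import (
    FileDiarizationInput,
    FileDiarizationOutput,
    FileDiarizationProcessor,
)


class FileDiarizationEchoProcessor(FileDiarizationProcessor):
    """Hypothetical backend: one full-length segment attributed to speaker 0."""

    async def _diarize(self, data: FileDiarizationInput) -> FileDiarizationOutput:
        # A real backend would fetch data.audio_url and run a diarization model
        return FileDiarizationOutput(
            diarization=[{"start": 0.0, "end": 60.0, "speaker": 0}]
        )
```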
@@ -1,33 +0,0 @@
import importlib

from reflector.processors.file_diarization import FileDiarizationProcessor
from reflector.settings import settings


class FileDiarizationAutoProcessor(FileDiarizationProcessor):
    _registry = {}

    @classmethod
    def register(cls, name, kclass):
        cls._registry[name] = kclass

    def __new__(cls, name: str | None = None, **kwargs):
        if name is None:
            name = settings.DIARIZATION_BACKEND

        if name not in cls._registry:
            module_name = f"reflector.processors.file_diarization_{name}"
            importlib.import_module(module_name)

        # gather specific configuration for the processor
        # search `DIARIZATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
        config = {}
        name_upper = name.upper()
        settings_prefix = "DIARIZATION_"
        config_prefix = f"{settings_prefix}{name_upper}_"
        for key, value in settings:
            if key.startswith(config_prefix):
                config_name = key[len(settings_prefix) :].lower()
                config[config_name] = value

        return cls._registry[name](**config | kwargs)
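Concretely, the settings scan above maps backend-scoped settings onto constructor kwargs; a sketch of the resolution path for the `modal` backend (setting names come from the `Settings` diff below, the values are illustrative):

```python
# With settings like:
#   DIARIZATION_BACKEND = "modal"
#   DIARIZATION_MODAL_API_KEY = "sk-..."
#
# FileDiarizationAutoProcessor() resolves as:
#   1. name = "modal"
#   2. importlib.import_module("reflector.processors.file_diarization_modal")
#   3. config = {"modal_api_key": "sk-..."}  # DIARIZATION_ prefix stripped, lowercased
#   4. returns FileDiarizationModalProcessor(modal_api_key="sk-...")
processor = FileDiarizationAutoProcessor()
```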
@@ -1,57 +0,0 @@
"""
File diarization implementation using the GPU service from modal.com

API will be a POST request to DIARIZATION_URL:

```
POST /diarize?audio_file_url=...&timestamp=0
Authorization: Bearer <modal_api_key>
```
"""

import httpx

from reflector.processors.file_diarization import (
    FileDiarizationInput,
    FileDiarizationOutput,
    FileDiarizationProcessor,
)
from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
from reflector.settings import settings


class FileDiarizationModalProcessor(FileDiarizationProcessor):
    def __init__(self, modal_api_key: str | None = None, **kwargs):
        super().__init__(**kwargs)
        if not settings.DIARIZATION_URL:
            raise Exception(
                "DIARIZATION_URL required to use FileDiarizationModalProcessor"
            )
        self.diarization_url = settings.DIARIZATION_URL + "/diarize"
        self.file_timeout = settings.DIARIZATION_FILE_TIMEOUT
        self.modal_api_key = modal_api_key

    async def _diarize(self, data: FileDiarizationInput):
        """Get speaker diarization for file"""
        self.logger.info(f"Starting diarization from {data.audio_url}")

        headers = {}
        if self.modal_api_key:
            headers["Authorization"] = f"Bearer {self.modal_api_key}"

        async with httpx.AsyncClient(timeout=self.file_timeout) as client:
            response = await client.post(
                self.diarization_url,
                headers=headers,
                params={
                    "audio_file_url": data.audio_url,
                    "timestamp": 0,
                },
            )
            response.raise_for_status()
            diarization_data = response.json()["diarization"]

        return FileDiarizationOutput(diarization=diarization_data)


FileDiarizationAutoProcessor.register("modal", FileDiarizationModalProcessor)
@@ -1,65 +0,0 @@
from prometheus_client import Counter, Histogram

from reflector.processors.base import Processor
from reflector.processors.types import Transcript


class FileTranscriptInput:
    """Input for file transcription containing audio URL and language settings"""

    def __init__(self, audio_url: str, language: str = "en"):
        self.audio_url = audio_url
        self.language = language


class FileTranscriptProcessor(Processor):
    """
    Transcribe complete audio files from URL
    """

    INPUT_TYPE = FileTranscriptInput
    OUTPUT_TYPE = Transcript

    m_transcript = Histogram(
        "file_transcript",
        "Time spent in FileTranscript.transcript",
        ["backend"],
    )
    m_transcript_call = Counter(
        "file_transcript_call",
        "Number of calls to FileTranscript.transcript",
        ["backend"],
    )
    m_transcript_success = Counter(
        "file_transcript_success",
        "Number of successful calls to FileTranscript.transcript",
        ["backend"],
    )
    m_transcript_failure = Counter(
        "file_transcript_failure",
        "Number of failed calls to FileTranscript.transcript",
        ["backend"],
    )

    def __init__(self, *args, **kwargs):
        name = self.__class__.__name__
        self.m_transcript = self.m_transcript.labels(name)
        self.m_transcript_call = self.m_transcript_call.labels(name)
        self.m_transcript_success = self.m_transcript_success.labels(name)
        self.m_transcript_failure = self.m_transcript_failure.labels(name)
        super().__init__(*args, **kwargs)

    async def _push(self, data: FileTranscriptInput):
        try:
            self.m_transcript_call.inc()
            with self.m_transcript.time():
                result = await self._transcript(data)
            self.m_transcript_success.inc()
            if result:
                await self.emit(result)
        except Exception:
            self.m_transcript_failure.inc()
            raise

    async def _transcript(self, data: FileTranscriptInput):
        raise NotImplementedError
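The class-level metrics are shared across all subclasses and re-bound per instance through `.labels(name)`, so each backend appears as its own label value; roughly what the Prometheus exposition would contain (sample values are illustrative):

```python
# After FileTranscriptModalProcessor.__init__ runs, each metric is a labeled
# child, so the exporter output distinguishes backends, e.g.:
#
#   file_transcript_call_total{backend="FileTranscriptModalProcessor"} 3.0
#   file_transcript_success_total{backend="FileTranscriptModalProcessor"} 2.0
#   file_transcript_failure_total{backend="FileTranscriptModalProcessor"} 1.0
#   file_transcript_count{backend="FileTranscriptModalProcessor"} 3.0
#   file_transcript_sum{backend="FileTranscriptModalProcessor"} 42.7
```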
@@ -1,32 +0,0 @@
import importlib

from reflector.processors.file_transcript import FileTranscriptProcessor
from reflector.settings import settings


class FileTranscriptAutoProcessor(FileTranscriptProcessor):
    _registry = {}

    @classmethod
    def register(cls, name, kclass):
        cls._registry[name] = kclass

    def __new__(cls, name: str | None = None, **kwargs):
        if name is None:
            name = settings.TRANSCRIPT_BACKEND
        if name not in cls._registry:
            module_name = f"reflector.processors.file_transcript_{name}"
            importlib.import_module(module_name)

        # gather specific configuration for the processor
        # search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
        config = {}
        name_upper = name.upper()
        settings_prefix = "TRANSCRIPT_"
        config_prefix = f"{settings_prefix}{name_upper}_"
        for key, value in settings:
            if key.startswith(config_prefix):
                config_name = key[len(settings_prefix) :].lower()
                config[config_name] = value

        return cls._registry[name](**config | kwargs)
@@ -1,74 +0,0 @@
"""
File transcription implementation using the GPU service from modal.com

API will be a POST request to TRANSCRIPT_URL:

```json
{
    "audio_file_url": "https://...",
    "language": "en",
    "model": "parakeet-tdt-0.6b-v2",
    "batch": true
}
```
"""

import httpx

from reflector.processors.file_transcript import (
    FileTranscriptInput,
    FileTranscriptProcessor,
)
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
from reflector.processors.types import Transcript, Word
from reflector.settings import settings


class FileTranscriptModalProcessor(FileTranscriptProcessor):
    def __init__(self, modal_api_key: str | None = None, **kwargs):
        super().__init__(**kwargs)
        if not settings.TRANSCRIPT_URL:
            raise Exception(
                "TRANSCRIPT_URL required to use FileTranscriptModalProcessor"
            )
        self.transcript_url = settings.TRANSCRIPT_URL
        self.file_timeout = settings.TRANSCRIPT_FILE_TIMEOUT
        self.modal_api_key = modal_api_key

    async def _transcript(self, data: FileTranscriptInput):
        """Send full file to Modal for transcription"""
        url = f"{self.transcript_url}/v1/audio/transcriptions-from-url"

        self.logger.info(f"Starting file transcription from {data.audio_url}")

        headers = {}
        if self.modal_api_key:
            headers["Authorization"] = f"Bearer {self.modal_api_key}"

        async with httpx.AsyncClient(timeout=self.file_timeout) as client:
            response = await client.post(
                url,
                headers=headers,
                json={
                    "audio_file_url": data.audio_url,
                    "language": data.language,
                    "batch": True,
                },
            )
            response.raise_for_status()
            result = response.json()

        words = [
            Word(
                text=word_info["word"],
                start=word_info["start"],
                end=word_info["end"],
            )
            for word_info in result.get("words", [])
        ]

        return Transcript(words=words)


# Register with the auto processor
FileTranscriptAutoProcessor.register("modal", FileTranscriptModalProcessor)
@@ -1,45 +0,0 @@
"""
Processor to assemble transcript with diarization results
"""

from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.processors.base import Processor
from reflector.processors.types import DiarizationSegment, Transcript


class TranscriptDiarizationAssemblerInput:
    """Input containing transcript and diarization data"""

    def __init__(self, transcript: Transcript, diarization: list[DiarizationSegment]):
        self.transcript = transcript
        self.diarization = diarization


class TranscriptDiarizationAssemblerProcessor(Processor):
    """
    Assemble transcript with diarization results by applying speaker assignments
    """

    INPUT_TYPE = TranscriptDiarizationAssemblerInput
    OUTPUT_TYPE = Transcript

    async def _push(self, data: TranscriptDiarizationAssemblerInput):
        result = await self._assemble(data)
        if result:
            await self.emit(result)

    async def _assemble(self, data: TranscriptDiarizationAssemblerInput):
        """Apply diarization to transcript words"""
        if not data.diarization:
            self.logger.info(
                "No diarization data provided, returning original transcript"
            )
            return data.transcript

        # Reuse logic from AudioDiarizationProcessor
        processor = AudioDiarizationProcessor()
        words = data.transcript.words
        processor.assign_speaker(words, data.diarization)

        self.logger.info(f"Applied diarization to {len(words)} words")
        return data.transcript
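A sketch of the data shapes flowing through `_assemble` (the types come from this diff; the module path for the assembler is assumed, and the exact overlap rule lives in `AudioDiarizationProcessor.assign_speaker` and is only paraphrased here):

```python
from reflector.processors.transcript_diarization_assembler import (  # assumed path
    TranscriptDiarizationAssemblerInput,
)
from reflector.processors.types import DiarizationSegment, Transcript, Word

transcript = Transcript(
    words=[
        Word(text="Hi", start=0.9, end=1.2),
        Word(text="there", start=1.2, end=1.4),
    ]
)
diarization: list[DiarizationSegment] = [{"start": 0.8, "end": 1.9, "speaker": 0}]

data = TranscriptDiarizationAssemblerInput(transcript, diarization)
# After _assemble, each Word overlapping a segment carries that segment's speaker.
```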
@@ -2,22 +2,13 @@ import io
 import re
 import tempfile
 from pathlib import Path
-from typing import Annotated, TypedDict
+from typing import Annotated

 from profanityfilter import ProfanityFilter
 from pydantic import BaseModel, Field, PrivateAttr

 from reflector.redis_cache import redis_cache


-class DiarizationSegment(TypedDict):
-    """Type definition for diarization segment containing speaker information"""
-
-    start: float
-    end: float
-    speaker: int
-
-
 PUNC_RE = re.compile(r"[.;:?!…]")

 profanity_filter = ProfanityFilter()
296	server/reflector/services/ics_sync.py	Normal file
@@ -0,0 +1,296 @@
import hashlib
from datetime import date, datetime, timedelta, timezone
from typing import TypedDict

import httpx
import pytz
from icalendar import Calendar, Event
from loguru import logger

from reflector.db.calendar_events import CalendarEvent, calendar_events_controller
from reflector.db.rooms import Room, rooms_controller
from reflector.settings import settings


class AttendeeData(TypedDict, total=False):
    email: str | None
    name: str | None
    status: str | None
    role: str | None


class EventData(TypedDict):
    ics_uid: str
    title: str | None
    description: str | None
    location: str | None
    start_time: datetime
    end_time: datetime
    attendees: list[AttendeeData]
    ics_raw_data: str


class SyncStats(TypedDict):
    events_created: int
    events_updated: int
    events_deleted: int


class ICSFetchService:
    def __init__(self):
        self.client = httpx.AsyncClient(
            timeout=30.0, headers={"User-Agent": "Reflector/1.0"}
        )

    async def fetch_ics(self, url: str) -> str:
        response = await self.client.get(url)
        response.raise_for_status()

        return response.text

    def parse_ics(self, ics_content: str) -> Calendar:
        return Calendar.from_ical(ics_content)

    def extract_room_events(
        self, calendar: Calendar, room_name: str, room_url: str
    ) -> list[EventData]:
        events = []
        now = datetime.now(timezone.utc)
        window_start = now - timedelta(hours=1)
        window_end = now + timedelta(hours=24)

        for component in calendar.walk():
            if component.name == "VEVENT":
                # Skip cancelled events
                status = component.get("STATUS", "").upper()
                if status == "CANCELLED":
                    continue

                # Check if event matches this room
                if self._event_matches_room(component, room_name, room_url):
                    event_data = self._parse_event(component)

                    # Only include events in our time window
                    if (
                        event_data
                        and window_start <= event_data["start_time"] <= window_end
                    ):
                        events.append(event_data)

        return events

    def _event_matches_room(self, event: Event, room_name: str, room_url: str) -> bool:
        location = str(event.get("LOCATION", ""))
        description = str(event.get("DESCRIPTION", ""))

        # Only match full room URL (with or without protocol)
        patterns = [
            room_url,  # Full URL with protocol
            room_url.replace("https://", ""),  # Without https protocol
            room_url.replace("http://", ""),  # Without http protocol
        ]

        # Check location and description for patterns
        text_to_check = f"{location} {description}".lower()

        for pattern in patterns:
            if pattern.lower() in text_to_check:
                return True

        return False

    def _parse_event(self, event: Event) -> EventData | None:
        # Extract basic fields
        uid = str(event.get("UID", ""))
        summary = str(event.get("SUMMARY", ""))
        description = str(event.get("DESCRIPTION", ""))
        location = str(event.get("LOCATION", ""))

        # Parse dates
        dtstart = event.get("DTSTART")
        dtend = event.get("DTEND")

        if not dtstart:
            return None

        # Convert to datetime
        start_time = self._normalize_datetime(
            dtstart.dt if hasattr(dtstart, "dt") else dtstart
        )
        end_time = (
            self._normalize_datetime(dtend.dt if hasattr(dtend, "dt") else dtend)
            if dtend
            else start_time + timedelta(hours=1)
        )

        # Parse attendees
        attendees = self._parse_attendees(event)

        # Get raw event data for storage
        raw_data = event.to_ical().decode("utf-8")

        return {
            "ics_uid": uid,
            "title": summary,
            "description": description,
            "location": location,
            "start_time": start_time,
            "end_time": end_time,
            "attendees": attendees,
            "ics_raw_data": raw_data,
        }

    def _normalize_datetime(self, dt) -> datetime:
        # Handle date objects (all-day events)
        if isinstance(dt, date) and not isinstance(dt, datetime):
            # Convert to datetime at start of day in UTC
            dt = datetime.combine(dt, datetime.min.time())
            dt = pytz.UTC.localize(dt)
        elif isinstance(dt, datetime):
            # Add UTC timezone if naive
            if dt.tzinfo is None:
                dt = pytz.UTC.localize(dt)
            else:
                # Convert to UTC
                dt = dt.astimezone(pytz.UTC)

        return dt

    def _parse_attendees(self, event: Event) -> list[AttendeeData]:
        attendees = []

        # Parse ATTENDEE properties
        for attendee in event.get("ATTENDEE", []):
            if not isinstance(attendee, list):
                attendee = [attendee]

            for att in attendee:
                att_data: AttendeeData = {
                    "email": str(att).replace("mailto:", "") if att else None,
                    "name": att.params.get("CN") if hasattr(att, "params") else None,
                    "status": att.params.get("PARTSTAT")
                    if hasattr(att, "params")
                    else None,
                    "role": att.params.get("ROLE") if hasattr(att, "params") else None,
                }
                attendees.append(att_data)

        # Add organizer
        organizer = event.get("ORGANIZER")
        if organizer:
            org_data: AttendeeData = {
                "email": str(organizer).replace("mailto:", "") if organizer else None,
                "name": organizer.params.get("CN")
                if hasattr(organizer, "params")
                else None,
                "role": "ORGANIZER",
            }
            attendees.append(org_data)

        return attendees


class ICSSyncService:
    def __init__(self):
        self.fetch_service = ICSFetchService()

    async def sync_room_calendar(self, room: Room) -> dict:
        if not room.ics_enabled or not room.ics_url:
            return {"status": "skipped", "reason": "ICS not configured"}

        try:
            # Check if it's time to sync
            if not self._should_sync(room):
                return {"status": "skipped", "reason": "Not time to sync yet"}

            # Fetch ICS file
            ics_content = await self.fetch_service.fetch_ics(room.ics_url)

            # Check if content changed
            content_hash = hashlib.md5(ics_content.encode()).hexdigest()
            if room.ics_last_etag == content_hash:
                logger.info(f"No changes in ICS for room {room.id}")
                return {"status": "unchanged", "hash": content_hash}

            # Parse calendar
            calendar = self.fetch_service.parse_ics(ics_content)

            # Build room URL
            room_url = f"{settings.BASE_URL}/room/{room.name}"

            # Extract matching events
            events = self.fetch_service.extract_room_events(
                calendar, room.name, room_url
            )

            # Sync events to database
            sync_result = await self._sync_events_to_database(room.id, events)

            # Update room sync metadata
            await rooms_controller.update(
                room,
                {
                    "ics_last_sync": datetime.now(timezone.utc),
                    "ics_last_etag": content_hash,
                },
                mutate=False,
            )

            return {
                "status": "success",
                "hash": content_hash,
                "events_found": len(events),
                **sync_result,
            }

        except Exception as e:
            logger.error(f"Failed to sync ICS for room {room.id}: {e}")
            return {"status": "error", "error": str(e)}

    def _should_sync(self, room: Room) -> bool:
        if not room.ics_last_sync:
            return True

        time_since_sync = datetime.now(timezone.utc) - room.ics_last_sync
        return time_since_sync.total_seconds() >= room.ics_fetch_interval

    async def _sync_events_to_database(
        self, room_id: str, events: list[EventData]
    ) -> SyncStats:
        created = 0
        updated = 0

        # Track current event IDs
        current_ics_uids = []

        for event_data in events:
            # Create CalendarEvent object
            calendar_event = CalendarEvent(room_id=room_id, **event_data)

            # Upsert event
            existing = await calendar_events_controller.get_by_ics_uid(
                room_id, event_data["ics_uid"]
            )

            if existing:
                updated += 1
            else:
                created += 1

            await calendar_events_controller.upsert(calendar_event)
            current_ics_uids.append(event_data["ics_uid"])

        # Soft delete events that are no longer in calendar
        deleted = await calendar_events_controller.soft_delete_missing(
            room_id, current_ics_uids
        )

        return {
            "events_created": created,
            "events_updated": updated,
            "events_deleted": deleted,
        }


# Global instance
ics_sync_service = ICSSyncService()
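A minimal driving sketch for the sync service above (the database connect/disconnect calls mirror the tools in this diff; the room name is illustrative):

```python
import asyncio

from reflector.db import database
from reflector.db.rooms import rooms_controller
from reflector.services.ics_sync import ics_sync_service


async def main():
    await database.connect()
    try:
        room = await rooms_controller.get_by_name("standup")  # illustrative name
        result = await ics_sync_service.sync_room_calendar(room)
        # status is one of: skipped / unchanged / success / error (see above)
        print(result["status"], result.get("events_found", 0))
    finally:
        await database.disconnect()


asyncio.run(main())
```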
@@ -26,7 +26,6 @@ class Settings(BaseSettings):
     TRANSCRIPT_BACKEND: str = "whisper"
     TRANSCRIPT_URL: str | None = None
     TRANSCRIPT_TIMEOUT: int = 90
-    TRANSCRIPT_FILE_TIMEOUT: int = 600

     # Audio Transcription: modal backend
     TRANSCRIPT_MODAL_API_KEY: str | None = None
@@ -67,14 +66,10 @@ class Settings(BaseSettings):
     DIARIZATION_ENABLED: bool = True
     DIARIZATION_BACKEND: str = "modal"
     DIARIZATION_URL: str | None = None
-    DIARIZATION_FILE_TIMEOUT: int = 600

-    # Diarization: modal backend
-    DIARIZATION_MODAL_API_KEY: str | None = None
-
     # Diarization: local pyannote.audio
     DIARIZATION_PYANNOTE_AUTH_TOKEN: str | None = None

     # Sentry
     SENTRY_DSN: str | None = None
@@ -1,23 +1,10 @@
"""
Process audio file with diarization support
===========================================

Extended version of process.py that includes speaker diarization.
This tool processes audio files locally without requiring the full server infrastructure.
"""

import asyncio
import tempfile
import uuid
from pathlib import Path
from typing import List

import av

from reflector.logger import logger
from reflector.processors import (
    AudioChunkerProcessor,
    AudioFileWriterProcessor,
    AudioMergeProcessor,
    AudioTranscriptAutoProcessor,
    Pipeline,
@@ -28,43 +15,7 @@ from reflector.processors import (
     TranscriptTopicDetectorProcessor,
     TranscriptTranslatorAutoProcessor,
 )
-from reflector.processors.base import BroadcastProcessor, Processor
-from reflector.processors.types import (
-    AudioDiarizationInput,
-    TitleSummary,
-    TitleSummaryWithId,
-)
-
-
-class TopicCollectorProcessor(Processor):
-    """Collect topics for diarization"""
-
-    INPUT_TYPE = TitleSummary
-    OUTPUT_TYPE = TitleSummary
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.topics: List[TitleSummaryWithId] = []
-        self._topic_id = 0
-
-    async def _push(self, data: TitleSummary):
-        # Convert to TitleSummaryWithId and collect
-        self._topic_id += 1
-        topic_with_id = TitleSummaryWithId(
-            id=str(self._topic_id),
-            title=data.title,
-            summary=data.summary,
-            timestamp=data.timestamp,
-            duration=data.duration,
-            transcript=data.transcript,
-        )
-        self.topics.append(topic_with_id)
-
-        # Pass through the original topic
-        await self.emit(data)
-
-    def get_topics(self) -> List[TitleSummaryWithId]:
-        return self.topics
+from reflector.processors.base import BroadcastProcessor


 async def process_audio_file(
@@ -73,40 +24,18 @@
     only_transcript=False,
     source_language="en",
     target_language="en",
-    enable_diarization=True,
-    diarization_backend="pyannote",
 ):
-    # Create temp file for audio if diarization is enabled
-    audio_temp_path = None
-    if enable_diarization:
-        audio_temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
-        audio_temp_path = audio_temp_file.name
-        audio_temp_file.close()
-
-    # Create processor for collecting topics
-    topic_collector = TopicCollectorProcessor()
-
-    # Build pipeline for audio processing
-    processors = []
-
-    # Add audio file writer at the beginning if diarization is enabled
-    if enable_diarization:
-        processors.append(AudioFileWriterProcessor(audio_temp_path))
-
-    # Add the rest of the processors
-    processors += [
+    # build pipeline for audio processing
+    processors = [
         AudioChunkerProcessor(),
         AudioMergeProcessor(),
        AudioTranscriptAutoProcessor.as_threaded(),
         TranscriptLinerProcessor(),
         TranscriptTranslatorAutoProcessor.as_threaded(),
     ]

     if not only_transcript:
         processors += [
             TranscriptTopicDetectorProcessor.as_threaded(),
-            # Collect topics for diarization
-            topic_collector,
             BroadcastProcessor(
                 processors=[
                     TranscriptFinalTitleProcessor.as_threaded(),
@@ -115,14 +44,14 @@
             ),
         ]

-    # Create main pipeline
+    # transcription output
     pipeline = Pipeline(*processors)
     pipeline.set_pref("audio:source_language", source_language)
     pipeline.set_pref("audio:target_language", target_language)
     pipeline.describe()
     pipeline.on(event_callback)

-    # Start processing audio
+    # start processing audio
     logger.info(f"Opening {filename}")
     container = av.open(filename)
     try:
@@ -133,242 +62,43 @@
         logger.info("Flushing the pipeline")
         await pipeline.flush()

-        # Run diarization if enabled and we have topics
-        if enable_diarization and not only_transcript and audio_temp_path:
-            topics = topic_collector.get_topics()
-
-            if topics:
-                logger.info(f"Starting diarization with {len(topics)} topics")
-
-                try:
-                    from reflector.processors import AudioDiarizationAutoProcessor
-
-                    diarization_processor = AudioDiarizationAutoProcessor(
-                        name=diarization_backend
-                    )
-
-                    diarization_processor.set_pipeline(pipeline)
-
-                    # For Modal backend, we need to upload the file to S3 first
-                    if diarization_backend == "modal":
-                        from datetime import datetime
-
-                        from reflector.storage import get_transcripts_storage
-                        from reflector.utils.s3_temp_file import S3TemporaryFile
-
-                        storage = get_transcripts_storage()
-
-                        # Generate a unique filename in evaluation folder
-                        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
-                        audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
-
-                        # Use context manager for automatic cleanup
-                        async with S3TemporaryFile(storage, audio_filename) as s3_file:
-                            # Read and upload the audio file
-                            with open(audio_temp_path, "rb") as f:
-                                audio_data = f.read()
-
-                            audio_url = await s3_file.upload(audio_data)
-                            logger.info(f"Uploaded audio to S3: {audio_filename}")
-
-                            # Create diarization input with S3 URL
-                            diarization_input = AudioDiarizationInput(
-                                audio_url=audio_url, topics=topics
-                            )
-
-                            # Run diarization
-                            await diarization_processor.push(diarization_input)
-                            await diarization_processor.flush()
-
-                            logger.info("Diarization complete")
-                            # File will be automatically cleaned up when exiting the context
-                    else:
-                        # For local backend, use local file path
-                        audio_url = audio_temp_path
-
-                        # Create diarization input
-                        diarization_input = AudioDiarizationInput(
-                            audio_url=audio_url, topics=topics
-                        )
-
-                        # Run diarization
-                        await diarization_processor.push(diarization_input)
-                        await diarization_processor.flush()
-
-                        logger.info("Diarization complete")
-
-                except ImportError as e:
-                    logger.error(f"Failed to import diarization dependencies: {e}")
-                    logger.error(
-                        "Install with: uv pip install pyannote.audio torch torchaudio"
-                    )
-                    logger.error(
-                        "And set HF_TOKEN environment variable for pyannote models"
-                    )
-                    raise SystemExit(1)
-                except Exception as e:
-                    logger.error(f"Diarization failed: {e}")
-                    raise SystemExit(1)
-            else:
-                logger.warning("Skipping diarization: no topics available")
-
-    # Clean up temp file
-    if audio_temp_path:
-        try:
-            Path(audio_temp_path).unlink()
-        except Exception as e:
-            logger.warning(f"Failed to clean up temp file {audio_temp_path}: {e}")
-
-    logger.info("All done!")
-
-
-async def process_file_pipeline(
-    filename: str,
-    event_callback,
-    source_language="en",
-    target_language="en",
-    enable_diarization=True,
-    diarization_backend="modal",
-):
-    """Process audio/video file using the optimized file pipeline"""
-    try:
-        from reflector.db import database
-        from reflector.db.transcripts import SourceKind, transcripts_controller
-        from reflector.pipelines.main_file_pipeline import PipelineMainFile
-
-        await database.connect()
-        try:
-            # Create a temporary transcript for processing
-            transcript = await transcripts_controller.add(
-                "",
-                source_kind=SourceKind.FILE,
-                source_language=source_language,
-                target_language=target_language,
-            )
-
-            # Process the file
-            pipeline = PipelineMainFile(transcript_id=transcript.id)
-            await pipeline.process(Path(filename))
-
-            logger.info("File pipeline processing complete")
-
-        finally:
-            await database.disconnect()
-    except ImportError as e:
-        logger.error(f"File pipeline not available: {e}")
-        logger.info("Falling back to stream pipeline")
-        # Fall back to stream pipeline
-        await process_audio_file(
-            filename,
-            event_callback,
-            only_transcript=False,
-            source_language=source_language,
-            target_language=target_language,
-            enable_diarization=enable_diarization,
-            diarization_backend=diarization_backend,
-        )
+    logger.info("All done !")


 if __name__ == "__main__":
     import argparse
     import os

-    parser = argparse.ArgumentParser(
-        description="Process audio files with optional speaker diarization"
-    )
+    parser = argparse.ArgumentParser()
     parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
-    parser.add_argument(
-        "--stream",
-        action="store_true",
-        help="Use streaming pipeline (original frame-based processing)",
-    )
-    parser.add_argument(
-        "--only-transcript",
-        "-t",
-        action="store_true",
-        help="Only generate transcript without topics/summaries",
-    )
-    parser.add_argument(
-        "--source-language", default="en", help="Source language code (default: en)"
-    )
-    parser.add_argument(
-        "--target-language", default="en", help="Target language code (default: en)"
-    )
+    parser.add_argument("--only-transcript", "-t", action="store_true")
+    parser.add_argument("--source-language", default="en")
+    parser.add_argument("--target-language", default="en")
     parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
-    parser.add_argument(
-        "--enable-diarization",
-        "-d",
-        action="store_true",
-        help="Enable speaker diarization",
-    )
-    parser.add_argument(
-        "--diarization-backend",
-        default="pyannote",
-        choices=["pyannote", "modal"],
-        help="Diarization backend to use (default: pyannote)",
-    )
     args = parser.parse_args()

     if "REDIS_HOST" not in os.environ:
         os.environ["REDIS_HOST"] = "localhost"

     output_fd = None
     if args.output:
         output_fd = open(args.output, "w")

     async def event_callback(event: PipelineEvent):
         processor = event.processor
         data = event.data

-        # Ignore internal processors
-        if processor in (
-            "AudioChunkerProcessor",
-            "AudioMergeProcessor",
-            "AudioFileWriterProcessor",
-            "TopicCollectorProcessor",
-            "BroadcastProcessor",
-        ):
+        # ignore some processor
+        if processor in ("AudioChunkerProcessor", "AudioMergeProcessor"):
             return

-        # If diarization is enabled, skip the original topic events from the pipeline
-        # The diarization processor will emit the same topics but with speaker info
-        if processor == "TranscriptTopicDetectorProcessor" and args.enable_diarization:
-            return
-
-        # Log all events
-        logger.info(f"Event: {processor} - {type(data).__name__}")
-
-        # Write to output
+        logger.info(f"Event: {event}")
         if output_fd:
             output_fd.write(event.model_dump_json())
             output_fd.write("\n")
             output_fd.flush()

-    if args.stream:
-        # Use original streaming pipeline
-        asyncio.run(
-            process_audio_file(
-                args.source,
-                event_callback,
-                only_transcript=args.only_transcript,
-                source_language=args.source_language,
-                target_language=args.target_language,
-                enable_diarization=args.enable_diarization,
-                diarization_backend=args.diarization_backend,
-            )
-        )
-    else:
-        # Use optimized file pipeline (default)
-        asyncio.run(
-            process_file_pipeline(
-                args.source,
-                event_callback,
-                source_language=args.source_language,
-                target_language=args.target_language,
-                enable_diarization=args.enable_diarization,
-                diarization_backend=args.diarization_backend,
-            )
-        )
+    asyncio.run(
+        process_audio_file(
+            args.source,
+            event_callback,
+            only_transcript=args.only_transcript,
+            source_language=args.source_language,
+            target_language=args.target_language,
+        )
+    )

     if output_fd:
         output_fd.close()
@@ -42,6 +42,11 @@ class Room(BaseModel):
     recording_type: str
     recording_trigger: str
     is_shared: bool
+    ics_url: Optional[str] = None
+    ics_fetch_interval: int = 300
+    ics_enabled: bool = False
+    ics_last_sync: Optional[datetime] = None
+    ics_last_etag: Optional[str] = None


 class Meeting(BaseModel):
@@ -64,18 +69,24 @@ class CreateRoom(BaseModel):
     recording_type: str
     recording_trigger: str
     is_shared: bool
+    ics_url: Optional[str] = None
+    ics_fetch_interval: int = 300
+    ics_enabled: bool = False


 class UpdateRoom(BaseModel):
-    name: str
-    zulip_auto_post: bool
-    zulip_stream: str
-    zulip_topic: str
-    is_locked: bool
-    room_mode: str
-    recording_type: str
-    recording_trigger: str
-    is_shared: bool
+    name: Optional[str] = None
+    zulip_auto_post: Optional[bool] = None
+    zulip_stream: Optional[str] = None
+    zulip_topic: Optional[str] = None
+    is_locked: Optional[bool] = None
+    room_mode: Optional[str] = None
+    recording_type: Optional[str] = None
+    recording_trigger: Optional[str] = None
+    is_shared: Optional[bool] = None
+    ics_url: Optional[str] = None
+    ics_fetch_interval: Optional[int] = None
+    ics_enabled: Optional[bool] = None


 class DeletionStatus(BaseModel):
@@ -117,6 +128,9 @@ async def rooms_create(
         recording_type=room.recording_type,
         recording_trigger=room.recording_trigger,
         is_shared=room.is_shared,
+        ics_url=room.ics_url,
+        ics_fetch_interval=room.ics_fetch_interval,
+        ics_enabled=room.ics_enabled,
     )

@@ -209,3 +223,217 @@ async def rooms_create_meeting(
         meeting.host_room_url = ""

     return meeting
+
+
+class ICSStatus(BaseModel):
+    status: str
+    last_sync: Optional[datetime] = None
+    next_sync: Optional[datetime] = None
+    last_etag: Optional[str] = None
+    events_count: int = 0
+
+
+class ICSSyncResult(BaseModel):
+    status: str
+    hash: Optional[str] = None
+    events_found: int = 0
+    events_created: int = 0
+    events_updated: int = 0
+    events_deleted: int = 0
+    error: Optional[str] = None
+
+
+@router.post("/rooms/{room_name}/ics/sync", response_model=ICSSyncResult)
+async def rooms_sync_ics(
+    room_name: str,
+    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+):
+    user_id = user["sub"] if user else None
+    room = await rooms_controller.get_by_name(room_name)
+
+    if not room:
+        raise HTTPException(status_code=404, detail="Room not found")
+
+    if user_id != room.user_id:
+        raise HTTPException(
+            status_code=403, detail="Only room owner can trigger ICS sync"
+        )
+
+    if not room.ics_enabled or not room.ics_url:
+        raise HTTPException(status_code=400, detail="ICS not configured for this room")
+
+    from reflector.services.ics_sync import ics_sync_service
+
+    result = await ics_sync_service.sync_room_calendar(room)
+
+    if result["status"] == "error":
+        raise HTTPException(
+            status_code=500, detail=result.get("error", "Unknown error")
+        )
+
+    return ICSSyncResult(**result)
+
+
+@router.get("/rooms/{room_name}/ics/status", response_model=ICSStatus)
+async def rooms_ics_status(
+    room_name: str,
+    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+):
+    user_id = user["sub"] if user else None
+    room = await rooms_controller.get_by_name(room_name)
+
+    if not room:
+        raise HTTPException(status_code=404, detail="Room not found")
+
+    if user_id != room.user_id:
+        raise HTTPException(
+            status_code=403, detail="Only room owner can view ICS status"
+        )
+
+    next_sync = None
+    if room.ics_enabled and room.ics_last_sync:
+        next_sync = room.ics_last_sync + timedelta(seconds=room.ics_fetch_interval)
+
+    from reflector.db.calendar_events import calendar_events_controller
+
+    events = await calendar_events_controller.get_by_room(
+        room.id, include_deleted=False
+    )
+
+    return ICSStatus(
+        status="enabled" if room.ics_enabled else "disabled",
+        last_sync=room.ics_last_sync,
+        next_sync=next_sync,
+        last_etag=room.ics_last_etag,
+        events_count=len(events),
+    )
+
+
+class CalendarEventResponse(BaseModel):
+    id: str
+    room_id: str
+    ics_uid: str
+    title: Optional[str] = None
+    description: Optional[str] = None
+    start_time: datetime
+    end_time: datetime
+    attendees: Optional[list[dict]] = None
+    location: Optional[str] = None
+    last_synced: datetime
+    created_at: datetime
+    updated_at: datetime
+
+
+@router.get("/rooms/{room_name}/meetings", response_model=list[CalendarEventResponse])
+async def rooms_list_meetings(
+    room_name: str,
+    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+):
+    user_id = user["sub"] if user else None
+    room = await rooms_controller.get_by_name(room_name)
+
+    if not room:
+        raise HTTPException(status_code=404, detail="Room not found")
+
+    from reflector.db.calendar_events import calendar_events_controller
+
+    events = await calendar_events_controller.get_by_room(
+        room.id, include_deleted=False
+    )
+
+    if user_id != room.user_id:
+        for event in events:
+            event.description = None
+            event.attendees = None
+
+    return events
+
+
+@router.get(
+    "/rooms/{room_name}/meetings/upcoming", response_model=list[CalendarEventResponse]
+)
+async def rooms_list_upcoming_meetings(
+    room_name: str,
+    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+    minutes_ahead: int = 30,
+):
+    user_id = user["sub"] if user else None
+    room = await rooms_controller.get_by_name(room_name)
+
+    if not room:
+        raise HTTPException(status_code=404, detail="Room not found")
+
+    from reflector.db.calendar_events import calendar_events_controller
+
+    events = await calendar_events_controller.get_upcoming(
+        room.id, minutes_ahead=minutes_ahead
+    )
+
+    if user_id != room.user_id:
+        for event in events:
+            event.description = None
+            event.attendees = None
+
+    return events
+
+
+@router.get("/rooms/{room_name}/meetings/active", response_model=list[Meeting])
+async def rooms_list_active_meetings(
+    room_name: str,
+    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+):
+    """List all active meetings for a room (supports multiple active meetings)"""
+    user_id = user["sub"] if user else None
+    room = await rooms_controller.get_by_name(room_name)
+
+    if not room:
+        raise HTTPException(status_code=404, detail="Room not found")
+
+    current_time = datetime.now(timezone.utc)
+    meetings = await meetings_controller.get_all_active_for_room(
+        room=room, current_time=current_time
+    )
+
+    # Hide host URLs from non-owners
+    if user_id != room.user_id:
+        for meeting in meetings:
+            meeting.host_room_url = ""
+
+    return meetings
+
+
+@router.post("/rooms/{room_name}/meetings/{meeting_id}/join", response_model=Meeting)
+async def rooms_join_meeting(
+    room_name: str,
+    meeting_id: str,
+    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+):
+    """Join a specific meeting by ID"""
+    user_id = user["sub"] if user else None
+    room = await rooms_controller.get_by_name(room_name)
+
+    if not room:
+        raise HTTPException(status_code=404, detail="Room not found")
+
+    meeting = await meetings_controller.get_by_id(meeting_id)
+
+    if not meeting:
+        raise HTTPException(status_code=404, detail="Meeting not found")
+
+    if meeting.room_id != room.id:
+        raise HTTPException(
+            status_code=403, detail="Meeting does not belong to this room"
+        )
+
+    if not meeting.is_active:
+        raise HTTPException(status_code=400, detail="Meeting is not active")
+
+    current_time = datetime.now(timezone.utc)
+    if meeting.end_date <= current_time:
+        raise HTTPException(status_code=400, detail="Meeting has ended")
+
+    # Hide host URL from non-owners
+    if user_id != room.user_id:
+        meeting.host_room_url = ""
+
+    return meeting
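From a client's point of view the new endpoints compose like this; a hedged sketch (the paths come from the routes above, but the `/v1` mount point and the bearer-token auth scheme are assumptions):

```python
import httpx

BASE = "https://reflector.example.com/v1"  # assumed mount point
headers = {"Authorization": "Bearer <token>"}  # assumed auth scheme

# Trigger a sync as the room owner, then inspect status and upcoming meetings
sync = httpx.post(f"{BASE}/rooms/standup/ics/sync", headers=headers)
status = httpx.get(f"{BASE}/rooms/standup/ics/status", headers=headers)
upcoming = httpx.get(
    f"{BASE}/rooms/standup/meetings/upcoming",
    params={"minutes_ahead": 60},
    headers=headers,
)
print(sync.json()["status"], status.json()["events_count"], len(upcoming.json()))
```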
@@ -160,7 +160,6 @@ async def transcripts_search(
     limit: SearchLimitParam = DEFAULT_SEARCH_LIMIT,
     offset: SearchOffsetParam = 0,
     room_id: Optional[str] = None,
-    source_kind: Optional[SourceKind] = None,
     user: Annotated[
         Optional[auth.UserInfo], Depends(auth.current_user_optional)
     ] = None,
@@ -174,12 +173,7 @@ async def transcripts_search(
     user_id = user["sub"] if user else None

     search_params = SearchParameters(
-        query_text=q,
-        limit=limit,
-        offset=offset,
-        user_id=user_id,
-        room_id=room_id,
-        source_kind=source_kind,
+        query_text=q, limit=limit, offset=offset, user_id=user_id, room_id=room_id
     )

     results, total = await search_controller.search_transcripts(search_params)
@@ -68,8 +68,13 @@ async def whereby_webhook(event: WherebyWebhookEvent, request: Request):
         raise HTTPException(status_code=404, detail="Meeting not found")

     if event.type in ["room.client.joined", "room.client.left"]:
-        await meetings_controller.update_meeting(
-            meeting.id, num_clients=event.data["numClients"]
-        )
+        update_data = {"num_clients": event.data["numClients"]}
+
+        # Clear grace period if participant joined
+        if event.type == "room.client.joined" and event.data["numClients"] > 0:
+            if meeting.last_participant_left_at:
+                update_data["last_participant_left_at"] = None
+
+        await meetings_controller.update_meeting(meeting.id, **update_data)

     return {"status": "ok"}
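For reference, the event shape this branch of the handler consumes; an illustrative payload (only `type` and `data.numClients` are read here, the remaining fields are assumptions about Whereby's webhook format):

```python
# Illustrative "room.client.joined" event as the handler would see it
event = {
    "type": "room.client.joined",
    "data": {
        "roomName": "/standup",  # assumption: used elsewhere to resolve the meeting
        "numClients": 2,         # read by this handler
    },
}
```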
@@ -19,6 +19,7 @@ else:
             "reflector.pipelines.main_live_pipeline",
             "reflector.worker.healthcheck",
             "reflector.worker.process",
+            "reflector.worker.ics_sync",
         ]
     )

@@ -36,6 +37,14 @@ else:
         "task": "reflector.worker.process.reprocess_failed_recordings",
         "schedule": crontab(hour=5, minute=0),  # Midnight EST
     },
+    "sync_all_ics_calendars": {
+        "task": "reflector.worker.ics_sync.sync_all_ics_calendars",
+        "schedule": 60.0,  # Run every minute to check which rooms need sync
+    },
+    "pre_create_upcoming_meetings": {
+        "task": "reflector.worker.ics_sync.pre_create_upcoming_meetings",
+        "schedule": 30.0,  # Run every 30 seconds to pre-create meetings
+    },
 }

 if settings.HEALTHCHECK_URL:
209	server/reflector/worker/ics_sync.py	Normal file
@@ -0,0 +1,209 @@
from datetime import datetime, timedelta, timezone

import structlog
from celery import shared_task
from celery.utils.log import get_task_logger

from reflector.db import get_database
from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms, rooms_controller
from reflector.services.ics_sync import ics_sync_service
from reflector.whereby import create_meeting, upload_logo

logger = structlog.wrap_logger(get_task_logger(__name__))


@shared_task
def sync_room_ics(room_id: str):
    asynctask(_sync_room_ics_async(room_id))


async def _sync_room_ics_async(room_id: str):
    try:
        room = await rooms_controller.get_by_id(room_id)
        if not room:
            logger.warning("Room not found for ICS sync", room_id=room_id)
            return

        if not room.ics_enabled or not room.ics_url:
            logger.debug("ICS not enabled for room", room_id=room_id)
            return

        logger.info("Starting ICS sync for room", room_id=room_id, room_name=room.name)
        result = await ics_sync_service.sync_room_calendar(room)

        if result["status"] == "success":
            logger.info(
                "ICS sync completed successfully",
                room_id=room_id,
                events_found=result.get("events_found", 0),
                events_created=result.get("events_created", 0),
                events_updated=result.get("events_updated", 0),
                events_deleted=result.get("events_deleted", 0),
            )
        elif result["status"] == "unchanged":
            logger.debug("ICS content unchanged", room_id=room_id)
        elif result["status"] == "error":
            logger.error("ICS sync failed", room_id=room_id, error=result.get("error"))
        else:
            logger.debug(
                "ICS sync skipped", room_id=room_id, reason=result.get("reason")
            )

    except Exception as e:
        logger.error("Unexpected error during ICS sync", room_id=room_id, error=str(e))


@shared_task
def sync_all_ics_calendars():
    asynctask(_sync_all_ics_calendars_async())


async def _sync_all_ics_calendars_async():
    try:
        logger.info("Starting sync for all ICS-enabled rooms")

        # Get ALL rooms - not filtered by is_shared
        query = rooms.select().where(
            rooms.c.ics_enabled == True, rooms.c.ics_url != None
        )
        all_rooms = await get_database().fetch_all(query)
        ics_enabled_rooms = list(all_rooms)

        logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")

        for room_data in ics_enabled_rooms:
            room_id = room_data["id"]
            room = await rooms_controller.get_by_id(room_id)

            if not room:
                continue

            if not _should_sync(room):
                logger.debug("Skipping room, not time to sync yet", room_id=room_id)
                continue

            sync_room_ics.delay(room_id)

        logger.info("Queued sync tasks for all eligible rooms")

    except Exception as e:
        logger.error("Error in sync_all_ics_calendars", error=str(e))


def _should_sync(room) -> bool:
    if not room.ics_last_sync:
        return True

    time_since_sync = datetime.now(timezone.utc) - room.ics_last_sync
    return time_since_sync.total_seconds() >= room.ics_fetch_interval


@shared_task
def pre_create_upcoming_meetings():
    asynctask(_pre_create_upcoming_meetings_async())


async def _pre_create_upcoming_meetings_async():
    try:
        logger.info("Starting pre-creation of upcoming meetings")

        from reflector.db.calendar_events import calendar_events_controller

        # Get ALL rooms with ICS enabled
        query = rooms.select().where(
            rooms.c.ics_enabled == True, rooms.c.ics_url != None
        )
        all_rooms = await get_database().fetch_all(query)
        now = datetime.now(timezone.utc)
        pre_create_window = now + timedelta(minutes=1)

        for room_data in all_rooms:
            room_id = room_data["id"]
            room = await rooms_controller.get_by_id(room_id)

            if not room:
                continue

            events = await calendar_events_controller.get_upcoming(
                room_id, minutes_ahead=2
            )

            for event in events:
                if event.start_time <= pre_create_window:
                    existing_meeting = await meetings_controller.get_by_calendar_event(
                        event.id
                    )

                    if not existing_meeting:
                        logger.info(
                            "Pre-creating meeting for calendar event",
                            room_id=room_id,
                            event_id=event.id,
                            event_title=event.title,
                        )

                        try:
                            end_date = event.end_time or (
                                event.start_time + timedelta(hours=1)
                            )

                            whereby_meeting = await create_meeting(
                                event.title or "Scheduled Meeting",
                                end_date=end_date,
                                room=room,
                            )
                            await upload_logo(
                                whereby_meeting["roomName"], "./images/logo.png"
                            )

                            meeting = await meetings_controller.create(
                                id=whereby_meeting["meetingId"],
                                room_name=whereby_meeting["roomName"],
                                room_url=whereby_meeting["roomUrl"],
                                host_room_url=whereby_meeting["hostRoomUrl"],
                                start_date=datetime.fromisoformat(
                                    whereby_meeting["startDate"]
                                ),
                                end_date=datetime.fromisoformat(
                                    whereby_meeting["endDate"]
                                ),
                                user_id=room.user_id,
                                room=room,
                                calendar_event_id=event.id,
                                calendar_metadata={
                                    "title": event.title,
                                    "description": event.description,
                                    "attendees": event.attendees,
                                },
                            )

                            logger.info(
                                "Meeting pre-created successfully",
                                meeting_id=meeting.id,
                                event_id=event.id,
                            )

                        except Exception as e:
                            logger.error(
                                "Failed to pre-create meeting",
                                room_id=room_id,
                                event_id=event.id,
                                error=str(e),
                            )

        logger.info("Completed pre-creation check for upcoming meetings")

    except Exception as e:
        logger.error("Error in pre_create_upcoming_meetings", error=str(e))


def asynctask(coro):
    import asyncio

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coro)
    finally:
        loop.close()
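The `asynctask` helper at the bottom is the sync-to-async bridge the Celery tasks above rely on; the same pattern in isolation (illustrative, not part of the diff):

```python
import asyncio

from celery import shared_task


async def do_work(x: int) -> int:
    await asyncio.sleep(0)  # stand-in for awaited I/O
    return x * 2


@shared_task
def do_work_task(x: int):
    # Each Celery invocation gets a fresh event loop and tears it down after,
    # so no loop state leaks between tasks in the same worker process.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(do_work(x))
    finally:
        loop.close()
```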
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from urllib.parse import unquote
|
||||
|
||||
import av
|
||||
@@ -14,8 +14,7 @@ from reflector.db.meetings import meetings_controller
|
||||
from reflector.db.recordings import Recording, recordings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
from reflector.db.transcripts import SourceKind, transcripts_controller
|
||||
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
||||
from reflector.pipelines.main_live_pipeline import asynctask
|
||||
from reflector.pipelines.main_live_pipeline import asynctask, task_pipeline_process
|
||||
from reflector.settings import settings
|
||||
from reflector.whereby import get_room_sessions
|
||||
|
||||
@@ -141,30 +140,82 @@ async def process_recording(bucket_name: str, object_key: str):

        await transcripts_controller.update(transcript, {"status": "uploaded"})

    task_pipeline_file_process.delay(transcript_id=transcript.id)
    task_pipeline_process.delay(transcript_id=transcript.id)


@shared_task
@asynctask
async def process_meetings():
    """
    Checks which meetings are still active and deactivates those that have ended.
    Supports multiple active meetings per room and grace period logic.
    """
    logger.info("Processing meetings")
    meetings = await meetings_controller.get_all_active()
    current_time = datetime.now(timezone.utc)

    for meeting in meetings:
        is_active = False
        should_deactivate = False
        end_date = meeting.end_date
        if end_date.tzinfo is None:
            end_date = end_date.replace(tzinfo=timezone.utc)
        if end_date > datetime.now(timezone.utc):

        # Check if meeting has passed its scheduled end time
        if end_date <= current_time:
            # For calendar meetings, force close 30 minutes after scheduled end
            if meeting.calendar_event_id:
                if current_time > end_date + timedelta(minutes=30):
                    should_deactivate = True
                    logger.info(
                        "Meeting %s forced closed 30 min after calendar end", meeting.id
                    )
            else:
                # Unscheduled meetings follow normal closure rules
                should_deactivate = True

        # Check Whereby room sessions only if not already deactivating
        if not should_deactivate and end_date > current_time:
            response = await get_room_sessions(meeting.room_name)
            room_sessions = response.get("results", [])
            is_active = not room_sessions or any(
            has_active_sessions = room_sessions and any(
                rs["endedAt"] is None for rs in room_sessions
            )
            if not is_active:

            if not has_active_sessions:
                # No active sessions - check grace period
                if meeting.num_clients == 0:
                    if meeting.last_participant_left_at:
                        # Check if grace period has expired
                        grace_period = timedelta(minutes=meeting.grace_period_minutes)
                        if (
                            current_time
                            > meeting.last_participant_left_at + grace_period
                        ):
                            should_deactivate = True
                            logger.info("Meeting %s grace period expired", meeting.id)
                    else:
                        # First time all participants left, record the time
                        await meetings_controller.update_meeting(
                            meeting.id, last_participant_left_at=current_time
                        )
                        logger.info(
                            "Meeting %s marked empty at %s", meeting.id, current_time
                        )
            else:
                # Has active sessions - clear grace period if set
                if meeting.last_participant_left_at:
                    await meetings_controller.update_meeting(
                        meeting.id, last_participant_left_at=None
                    )
                    logger.info(
                        "Meeting %s reactivated - participant rejoined", meeting.id
                    )

        if should_deactivate:
            await meetings_controller.update_meeting(meeting.id, is_active=False)
            logger.info("Meeting %s is deactivated", meeting.id)

    logger.info("Processed meetings")
    logger.info("Processed %d meetings", len(meetings))


@shared_task
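Taken together, the new process_meetings logic above has three closure paths: calendar-backed meetings are force-closed 30 minutes after their scheduled end, unscheduled meetings close as soon as their end date passes, and still-scheduled meetings close only once the room is empty and the per-meeting grace period has elapsed. A condensed sketch of that decision (field names come from the diff; the standalone helper itself is hypothetical):

from datetime import timedelta

def should_close(meeting, end_date, has_active_sessions, now):
    # Past the scheduled end: calendar meetings get a 30-minute overrun window
    if end_date <= now:
        if meeting.calendar_event_id:
            return now > end_date + timedelta(minutes=30)
        return True
    # Still scheduled: close only if the room is empty and the grace period expired
    if not has_active_sessions and meeting.num_clients == 0:
        left_at = meeting.last_participant_left_at
        if left_at is None:
            return False  # the first empty check only records the timestamp
        return now > left_at + timedelta(minutes=meeting.grace_period_minutes)
    return False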
@@ -1,40 +0,0 @@
interactions:
- request:
    body: ''
    headers:
      accept:
      - '*/*'
      accept-encoding:
      - gzip, deflate
      authorization:
      - DUMMY_API_KEY
      connection:
      - keep-alive
      content-length:
      - '0'
      host:
      - monadical-sas--reflector-diarizer-web.modal.run
      user-agent:
      - python-httpx/0.27.2
    method: POST
    uri: https://monadical-sas--reflector-diarizer-web.modal.run/diarize?audio_file_url=https%3A%2F%2Freflector-github-pytest.s3.us-east-1.amazonaws.com%2Ftest_mathieu_hello.mp3&timestamp=0
  response:
    body:
      string: '{"diarization":[{"start":0.823,"end":1.91,"speaker":0},{"start":2.572,"end":6.409,"speaker":0},{"start":6.783,"end":10.62,"speaker":0},{"start":11.231,"end":14.168,"speaker":0},{"start":14.796,"end":19.295,"speaker":0}]}'
    headers:
      Alt-Svc:
      - h3=":443"; ma=2592000
      Content-Length:
      - '220'
      Content-Type:
      - application/json
      Date:
      - Wed, 13 Aug 2025 18:25:34 GMT
      Modal-Function-Call-Id:
      - fc-01K2JAVNEP6N7Y1Y7W3T98BCXK
      Vary:
      - accept-encoding
    status:
      code: 200
      message: OK
version: 1
@@ -1,46 +0,0 @@
interactions:
- request:
    body: '{"audio_file_url": "https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3",
      "language": "en", "batch": true}'
    headers:
      accept:
      - '*/*'
      accept-encoding:
      - gzip, deflate
      authorization:
      - DUMMY_API_KEY
      connection:
      - keep-alive
      content-length:
      - '136'
      content-type:
      - application/json
      host:
      - monadical-sas--reflector-transcriber-parakeet-web.modal.run
      user-agent:
      - python-httpx/0.27.2
    method: POST
    uri: https://monadical-sas--reflector-transcriber-parakeet-web.modal.run/v1/audio/transcriptions-from-url
  response:
    body:
      string: '{"text":"Hi there everyone. Today I want to share my incredible experience
        with Reflector. a Q teenage product that revolutionizes audio processing.
        With reflector, I can easily convert any audio into accurate transcription.
        saving me hours of tedious manual work.","words":[{"word":"Hi","start":0.87,"end":1.19},{"word":"there","start":1.19,"end":1.35},{"word":"everyone.","start":1.51,"end":1.83},{"word":"Today","start":2.63,"end":2.87},{"word":"I","start":3.36,"end":3.52},{"word":"want","start":3.6,"end":3.76},{"word":"to","start":3.76,"end":3.92},{"word":"share","start":3.92,"end":4.16},{"word":"my","start":4.16,"end":4.4},{"word":"incredible","start":4.32,"end":4.96},{"word":"experience","start":4.96,"end":5.44},{"word":"with","start":5.44,"end":5.68},{"word":"Reflector.","start":5.68,"end":6.24},{"word":"a","start":6.93,"end":7.01},{"word":"Q","start":7.01,"end":7.17},{"word":"teenage","start":7.25,"end":7.65},{"word":"product","start":7.89,"end":8.29},{"word":"that","start":8.29,"end":8.61},{"word":"revolutionizes","start":8.61,"end":9.65},{"word":"audio","start":9.65,"end":10.05},{"word":"processing.","start":10.05,"end":10.53},{"word":"With","start":11.27,"end":11.43},{"word":"reflector,","start":11.51,"end":12.15},{"word":"I","start":12.31,"end":12.39},{"word":"can","start":12.39,"end":12.55},{"word":"easily","start":12.55,"end":12.95},{"word":"convert","start":12.95,"end":13.43},{"word":"any","start":13.43,"end":13.67},{"word":"audio","start":13.67,"end":13.99},{"word":"into","start":14.98,"end":15.06},{"word":"accurate","start":15.22,"end":15.54},{"word":"transcription.","start":15.7,"end":16.34},{"word":"saving","start":16.99,"end":17.15},{"word":"me","start":17.31,"end":17.47},{"word":"hours","start":17.47,"end":17.87},{"word":"of","start":17.87,"end":18.11},{"word":"tedious","start":18.11,"end":18.67},{"word":"manual","start":18.67,"end":19.07},{"word":"work.","start":19.07,"end":19.31}]}'
    headers:
      Alt-Svc:
      - h3=":443"; ma=2592000
      Content-Length:
      - '1933'
      Content-Type:
      - application/json
      Date:
      - Wed, 13 Aug 2025 18:26:59 GMT
      Modal-Function-Call-Id:
      - fc-01K2JAWC7GAMKX4DSJ21WV31NG
      Vary:
      - accept-encoding
    status:
      code: 200
      message: OK
version: 1
@@ -1,84 +0,0 @@
interactions:
- request:
    body: '{"audio_file_url": "https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3",
      "language": "en", "batch": true}'
    headers:
      accept:
      - '*/*'
      accept-encoding:
      - gzip, deflate
      authorization:
      - DUMMY_API_KEY
      connection:
      - keep-alive
      content-length:
      - '136'
      content-type:
      - application/json
      host:
      - monadical-sas--reflector-transcriber-parakeet-web.modal.run
      user-agent:
      - python-httpx/0.27.2
    method: POST
    uri: https://monadical-sas--reflector-transcriber-parakeet-web.modal.run/v1/audio/transcriptions-from-url
  response:
    body:
      string: '{"text":"Hi there everyone. Today I want to share my incredible experience
        with Reflector. a Q teenage product that revolutionizes audio processing.
        With reflector, I can easily convert any audio into accurate transcription.
        saving me hours of tedious manual work.","words":[{"word":"Hi","start":0.87,"end":1.19},{"word":"there","start":1.19,"end":1.35},{"word":"everyone.","start":1.51,"end":1.83},{"word":"Today","start":2.63,"end":2.87},{"word":"I","start":3.36,"end":3.52},{"word":"want","start":3.6,"end":3.76},{"word":"to","start":3.76,"end":3.92},{"word":"share","start":3.92,"end":4.16},{"word":"my","start":4.16,"end":4.4},{"word":"incredible","start":4.32,"end":4.96},{"word":"experience","start":4.96,"end":5.44},{"word":"with","start":5.44,"end":5.68},{"word":"Reflector.","start":5.68,"end":6.24},{"word":"a","start":6.93,"end":7.01},{"word":"Q","start":7.01,"end":7.17},{"word":"teenage","start":7.25,"end":7.65},{"word":"product","start":7.89,"end":8.29},{"word":"that","start":8.29,"end":8.61},{"word":"revolutionizes","start":8.61,"end":9.65},{"word":"audio","start":9.65,"end":10.05},{"word":"processing.","start":10.05,"end":10.53},{"word":"With","start":11.27,"end":11.43},{"word":"reflector,","start":11.51,"end":12.15},{"word":"I","start":12.31,"end":12.39},{"word":"can","start":12.39,"end":12.55},{"word":"easily","start":12.55,"end":12.95},{"word":"convert","start":12.95,"end":13.43},{"word":"any","start":13.43,"end":13.67},{"word":"audio","start":13.67,"end":13.99},{"word":"into","start":14.98,"end":15.06},{"word":"accurate","start":15.22,"end":15.54},{"word":"transcription.","start":15.7,"end":16.34},{"word":"saving","start":16.99,"end":17.15},{"word":"me","start":17.31,"end":17.47},{"word":"hours","start":17.47,"end":17.87},{"word":"of","start":17.87,"end":18.11},{"word":"tedious","start":18.11,"end":18.67},{"word":"manual","start":18.67,"end":19.07},{"word":"work.","start":19.07,"end":19.31}]}'
    headers:
      Alt-Svc:
      - h3=":443"; ma=2592000
      Content-Length:
      - '1933'
      Content-Type:
      - application/json
      Date:
      - Wed, 13 Aug 2025 18:27:02 GMT
      Modal-Function-Call-Id:
      - fc-01K2JAYZ1AR2HE422VJVKBWX9Z
      Vary:
      - accept-encoding
    status:
      code: 200
      message: OK
- request:
    body: ''
    headers:
      accept:
      - '*/*'
      accept-encoding:
      - gzip, deflate
      authorization:
      - DUMMY_API_KEY
      connection:
      - keep-alive
      content-length:
      - '0'
      host:
      - monadical-sas--reflector-diarizer-web.modal.run
      user-agent:
      - python-httpx/0.27.2
    method: POST
    uri: https://monadical-sas--reflector-diarizer-web.modal.run/diarize?audio_file_url=https%3A%2F%2Freflector-github-pytest.s3.us-east-1.amazonaws.com%2Ftest_mathieu_hello.mp3&timestamp=0
  response:
    body:
      string: '{"diarization":[{"start":0.823,"end":1.91,"speaker":0},{"start":2.572,"end":6.409,"speaker":0},{"start":6.783,"end":10.62,"speaker":0},{"start":11.231,"end":14.168,"speaker":0},{"start":14.796,"end":19.295,"speaker":0}]}'
    headers:
      Alt-Svc:
      - h3=":443"; ma=2592000
      Content-Length:
      - '220'
      Content-Type:
      - application/json
      Date:
      - Wed, 13 Aug 2025 18:27:18 GMT
      Modal-Function-Call-Id:
      - fc-01K2JAZ1M34NQRJK03CCFK95D6
      Vary:
      - accept-encoding
    status:
      code: 200
      message: OK
version: 1
@@ -5,29 +5,7 @@ from unittest.mock import patch

import pytest


@pytest.fixture(scope="session", autouse=True)
def settings_configuration():
    # these settings are linked to Monadical for pytest-recording
    # if a fork is made, it has to provide its own URLs when cassettes need to be updated
    # Modal API keys have to be defined by the user
    from reflector.settings import settings

    settings.TRANSCRIPT_BACKEND = "modal"
    settings.TRANSCRIPT_URL = (
        "https://monadical-sas--reflector-transcriber-parakeet-web.modal.run"
    )
    settings.DIARIZATION_BACKEND = "modal"
    settings.DIARIZATION_URL = "https://monadical-sas--reflector-diarizer-web.modal.run"


@pytest.fixture(scope="module")
def vcr_config():
    """VCR configuration to filter sensitive headers"""
    return {
        "filter_headers": [("authorization", "DUMMY_API_KEY")],
    }


# Pytest-docker configuration
@pytest.fixture(scope="session")
def docker_compose_file(pytestconfig):
    return os.path.join(str(pytestconfig.rootdir), "tests", "docker-compose.test.yml")
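This vcr_config explains the cassettes above: pytest-recording (via vcrpy) rewrites the real authorization header to the literal DUMMY_API_KEY at record time, which is why every recorded request carries that value. The same redaction applied with vcrpy directly would look roughly like this (a sketch, not the project's actual wiring):

import vcr

# (key, replacement) tuples substitute the header value instead of dropping it
redacting_vcr = vcr.VCR(filter_headers=[("authorization", "DUMMY_API_KEY")])

with redacting_vcr.use_cassette("tests/cassettes/example.yaml"):
    pass  # any HTTP call recorded here is stored with its auth header redacted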
@@ -1,7 +1,7 @@
version: "3.8"
version: '3.8'
services:
  postgres_test:
    image: postgres:17
    image: postgres:15
    environment:
      POSTGRES_DB: reflector_test
      POSTGRES_USER: test_user
@@ -10,4 +10,4 @@ services:
      - "15432:5432"
    command: postgres -c fsync=off -c synchronous_commit=off -c full_page_writes=off
    tmpfs:
      - /var/lib/postgresql/data:rw,noexec,nosuid,size=1g
      - /var/lib/postgresql/data:rw,noexec,nosuid,size=1g
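The compose tuning above is purely for test speed: fsync=off plus a tmpfs data directory trades durability (irrelevant for a throwaway test database) for much faster writes. For context, pytest-docker consumes the docker_compose_file fixture from conftest.py and exposes the started services; a typical consumer fixture might look like this (a sketch with a hypothetical _can_connect helper; the password comes from POSTGRES_PASSWORD, which the hunk above truncates, and the project's real wiring may differ):

import pytest

@pytest.fixture(scope="session")
def postgres_test_dsn(docker_ip, docker_services):
    # port_for maps container port 5432 to the published host port (15432 here)
    port = docker_services.port_for("postgres_test", 5432)
    dsn = f"postgresql://test_user:CHANGE_ME@{docker_ip}:{port}/reflector_test"
    # Block until Postgres accepts connections before handing the DSN to tests
    docker_services.wait_until_responsive(
        timeout=30.0, pause=0.5, check=lambda: _can_connect(dsn)
    )
    return dsn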
351
server/tests/test_calendar_event.py
Normal file
351
server/tests/test_calendar_event.py
Normal file
@@ -0,0 +1,351 @@
"""
Tests for CalendarEvent model.
"""

from datetime import datetime, timedelta, timezone

import pytest

from reflector.db.calendar_events import CalendarEvent, calendar_events_controller
from reflector.db.rooms import rooms_controller


@pytest.mark.asyncio
async def test_calendar_event_create():
    """Test creating a calendar event."""
    # Create a room first
    room = await rooms_controller.add(
        name="test-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
    )

    # Create calendar event
    now = datetime.now(timezone.utc)
    event = CalendarEvent(
        room_id=room.id,
        ics_uid="test-event-123",
        title="Team Meeting",
        description="Weekly team sync",
        start_time=now + timedelta(hours=1),
        end_time=now + timedelta(hours=2),
        location=f"https://example.com/room/{room.name}",
        attendees=[
            {"email": "alice@example.com", "name": "Alice", "status": "ACCEPTED"},
            {"email": "bob@example.com", "name": "Bob", "status": "TENTATIVE"},
        ],
    )

    # Save event
    saved_event = await calendar_events_controller.upsert(event)

    assert saved_event.ics_uid == "test-event-123"
    assert saved_event.title == "Team Meeting"
    assert saved_event.room_id == room.id
    assert len(saved_event.attendees) == 2


@pytest.mark.asyncio
async def test_calendar_event_get_by_room():
    """Test getting calendar events for a room."""
    # Create room
    room = await rooms_controller.add(
        name="events-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
    )

    now = datetime.now(timezone.utc)

    # Create multiple events
    for i in range(3):
        event = CalendarEvent(
            room_id=room.id,
            ics_uid=f"event-{i}",
            title=f"Meeting {i}",
            start_time=now + timedelta(hours=i),
            end_time=now + timedelta(hours=i + 1),
        )
        await calendar_events_controller.upsert(event)

    # Get events for room
    events = await calendar_events_controller.get_by_room(room.id)

    assert len(events) == 3
    assert all(e.room_id == room.id for e in events)
    assert events[0].title == "Meeting 0"
    assert events[1].title == "Meeting 1"
    assert events[2].title == "Meeting 2"


@pytest.mark.asyncio
async def test_calendar_event_get_upcoming():
    """Test getting upcoming events within time window."""
    # Create room
    room = await rooms_controller.add(
        name="upcoming-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
    )

    now = datetime.now(timezone.utc)

    # Create events at different times
    # Past event (should not be included)
    past_event = CalendarEvent(
        room_id=room.id,
        ics_uid="past-event",
        title="Past Meeting",
        start_time=now - timedelta(hours=2),
        end_time=now - timedelta(hours=1),
    )
    await calendar_events_controller.upsert(past_event)

    # Upcoming event within 30 minutes
    upcoming_event = CalendarEvent(
        room_id=room.id,
        ics_uid="upcoming-event",
        title="Upcoming Meeting",
        start_time=now + timedelta(minutes=15),
        end_time=now + timedelta(minutes=45),
    )
    await calendar_events_controller.upsert(upcoming_event)

    # Future event beyond 30 minutes
    future_event = CalendarEvent(
        room_id=room.id,
        ics_uid="future-event",
        title="Future Meeting",
        start_time=now + timedelta(hours=2),
        end_time=now + timedelta(hours=3),
    )
    await calendar_events_controller.upsert(future_event)

    # Get upcoming events (default 30 minutes)
    upcoming = await calendar_events_controller.get_upcoming(room.id)

    assert len(upcoming) == 1
    assert upcoming[0].ics_uid == "upcoming-event"

    # Get upcoming with custom window
    upcoming_extended = await calendar_events_controller.get_upcoming(
        room.id, minutes_ahead=180
    )

    assert len(upcoming_extended) == 2
    assert upcoming_extended[0].ics_uid == "upcoming-event"
    assert upcoming_extended[1].ics_uid == "future-event"


@pytest.mark.asyncio
async def test_calendar_event_upsert():
    """Test upserting (create/update) calendar events."""
    # Create room
    room = await rooms_controller.add(
        name="upsert-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
    )

    now = datetime.now(timezone.utc)

    # Create new event
    event = CalendarEvent(
        room_id=room.id,
        ics_uid="upsert-test",
        title="Original Title",
        start_time=now,
        end_time=now + timedelta(hours=1),
    )

    created = await calendar_events_controller.upsert(event)
    assert created.title == "Original Title"

    # Update existing event
    event.title = "Updated Title"
    event.description = "Added description"

    updated = await calendar_events_controller.upsert(event)
    assert updated.title == "Updated Title"
    assert updated.description == "Added description"
    assert updated.ics_uid == "upsert-test"

    # Verify only one event exists
    events = await calendar_events_controller.get_by_room(room.id)
    assert len(events) == 1
    assert events[0].title == "Updated Title"


@pytest.mark.asyncio
async def test_calendar_event_soft_delete():
    """Test soft deleting events no longer in calendar."""
    # Create room
    room = await rooms_controller.add(
        name="delete-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
    )

    now = datetime.now(timezone.utc)

    # Create multiple events
    for i in range(4):
        event = CalendarEvent(
            room_id=room.id,
            ics_uid=f"event-{i}",
            title=f"Meeting {i}",
            start_time=now + timedelta(hours=i),
            end_time=now + timedelta(hours=i + 1),
        )
        await calendar_events_controller.upsert(event)

    # Soft delete events not in current list
    current_ids = ["event-0", "event-2"]  # Keep events 0 and 2
    deleted_count = await calendar_events_controller.soft_delete_missing(
        room.id, current_ids
    )

    assert deleted_count == 2  # Should delete events 1 and 3

    # Get non-deleted events
    events = await calendar_events_controller.get_by_room(
        room.id, include_deleted=False
    )
    assert len(events) == 2
    assert {e.ics_uid for e in events} == {"event-0", "event-2"}

    # Get all events including deleted
    all_events = await calendar_events_controller.get_by_room(
        room.id, include_deleted=True
    )
    assert len(all_events) == 4


@pytest.mark.asyncio
async def test_calendar_event_past_events_not_deleted():
    """Test that past events are not soft deleted."""
    # Create room
    room = await rooms_controller.add(
        name="past-events-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
    )

    now = datetime.now(timezone.utc)

    # Create past event
    past_event = CalendarEvent(
        room_id=room.id,
        ics_uid="past-event",
        title="Past Meeting",
        start_time=now - timedelta(hours=2),
        end_time=now - timedelta(hours=1),
    )
    await calendar_events_controller.upsert(past_event)

    # Create future event
    future_event = CalendarEvent(
        room_id=room.id,
        ics_uid="future-event",
        title="Future Meeting",
        start_time=now + timedelta(hours=1),
        end_time=now + timedelta(hours=2),
    )
    await calendar_events_controller.upsert(future_event)

    # Try to soft delete all events (only future should be deleted)
    deleted_count = await calendar_events_controller.soft_delete_missing(room.id, [])

    assert deleted_count == 1  # Only future event deleted

    # Verify past event still exists
    events = await calendar_events_controller.get_by_room(
        room.id, include_deleted=False
    )
    assert len(events) == 1
    assert events[0].ics_uid == "past-event"


@pytest.mark.asyncio
async def test_calendar_event_with_raw_ics_data():
    """Test storing raw ICS data with calendar event."""
    # Create room
    room = await rooms_controller.add(
        name="raw-ics-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
    )

    raw_ics = """BEGIN:VEVENT
UID:test-raw-123
SUMMARY:Test Event
DTSTART:20240101T100000Z
DTEND:20240101T110000Z
END:VEVENT"""

    event = CalendarEvent(
        room_id=room.id,
        ics_uid="test-raw-123",
        title="Test Event",
        start_time=datetime.now(timezone.utc),
        end_time=datetime.now(timezone.utc) + timedelta(hours=1),
        ics_raw_data=raw_ics,
    )

    saved = await calendar_events_controller.upsert(event)

    assert saved.ics_raw_data == raw_ics

    # Retrieve and verify
    retrieved = await calendar_events_controller.get_by_ics_uid(room.id, "test-raw-123")
    assert retrieved is not None
    assert retrieved.ics_raw_data == raw_ics
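The last two tests pin down the soft_delete_missing contract: events absent from the current ICS feed are marked deleted rather than removed, and events that have already ended are left alone even when missing. The controller implementation is not part of this diff, so the following is only a sketch of that contract:

from datetime import datetime, timezone

def soft_delete_missing(events, current_ids):
    # events: stored CalendarEvent rows for one room; returns number soft-deleted
    now = datetime.now(timezone.utc)
    deleted = 0
    for event in events:
        if event.ics_uid in current_ids:
            continue  # still present in the feed
        if event.end_time <= now:
            continue  # past events are kept as a historical record (assumed cutoff)
        event.deleted = True  # hypothetical flag; the real schema may differ
        deleted += 1
    return deleted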
@@ -1,330 +0,0 @@
"""
Tests for GPU Modal transcription endpoints.

These tests are marked with the "gpu_modal" group and will not run by default.
Run them with: pytest -m gpu_modal tests/test_gpu_modal_transcript_parakeet.py

Required environment variables:
- TRANSCRIPT_URL: URL to the Modal.com endpoint (required)
- TRANSCRIPT_MODAL_API_KEY: API key for authentication (optional)
- TRANSCRIPT_MODEL: Model name to use (optional, defaults to nvidia/parakeet-tdt-0.6b-v2)

Example with pytest (override default addopts to run ONLY gpu_modal tests):
    TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-parakeet-web-dev.modal.run \
    TRANSCRIPT_MODAL_API_KEY=your-api-key \
    uv run -m pytest -m gpu_modal --no-cov tests/test_gpu_modal_transcript.py

    # Or with completely clean options:
    uv run -m pytest -m gpu_modal -o addopts="" tests/

Running Modal locally for testing:
    modal serve gpu/modal_deployments/reflector_transcriber_parakeet.py
    # This will give you a local URL like https://xxxxx--reflector-transcriber-parakeet-web-dev.modal.run to test against
"""

import os
import tempfile
from pathlib import Path

import httpx
import pytest

# Test audio file URL for testing
TEST_AUDIO_URL = (
    "https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3"
)


def get_modal_transcript_url():
    """Get and validate the Modal transcript URL from environment."""
    url = os.environ.get("TRANSCRIPT_URL")
    if not url:
        pytest.skip(
            "TRANSCRIPT_URL environment variable is required for GPU Modal tests"
        )
    return url


def get_auth_headers():
    """Get authentication headers if API key is available."""
    api_key = os.environ.get("TRANSCRIPT_MODAL_API_KEY")
    if api_key:
        return {"Authorization": f"Bearer {api_key}"}
    return {}


def get_model_name():
    """Get the model name from environment or use default."""
    return os.environ.get("TRANSCRIPT_MODEL", "nvidia/parakeet-tdt-0.6b-v2")


@pytest.mark.gpu_modal
class TestGPUModalTranscript:
    """Test suite for GPU Modal transcription endpoints."""

    def test_transcriptions_from_url(self):
        """Test the /v1/audio/transcriptions-from-url endpoint."""
        url = get_modal_transcript_url()
        headers = get_auth_headers()

        with httpx.Client(timeout=60.0) as client:
            response = client.post(
                f"{url}/v1/audio/transcriptions-from-url",
                json={
                    "audio_file_url": TEST_AUDIO_URL,
                    "model": get_model_name(),
                    "language": "en",
                    "timestamp_offset": 0.0,
                },
                headers=headers,
            )

            assert response.status_code == 200, f"Request failed: {response.text}"
            result = response.json()

            # Verify response structure
            assert "text" in result
            assert "words" in result
            assert isinstance(result["text"], str)
            assert isinstance(result["words"], list)

            # Verify content is meaningful
            assert len(result["text"]) > 0, "Transcript text should not be empty"
            assert len(result["words"]) > 0, "Words list must not be empty"

            # Verify word structure
            for word in result["words"]:
                assert "word" in word
                assert "start" in word
                assert "end" in word
                assert isinstance(word["start"], (int, float))
                assert isinstance(word["end"], (int, float))
                assert word["start"] <= word["end"]

    def test_transcriptions_single_file(self):
        """Test the /v1/audio/transcriptions endpoint with a single file."""
        url = get_modal_transcript_url()
        headers = get_auth_headers()

        # Download test audio file to upload
        with httpx.Client(timeout=60.0) as client:
            audio_response = client.get(TEST_AUDIO_URL)
            audio_response.raise_for_status()

            with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_file:
                tmp_file.write(audio_response.content)
                tmp_file_path = tmp_file.name

            try:
                # Upload the file for transcription
                with open(tmp_file_path, "rb") as f:
                    files = {"file": ("test_audio.mp3", f, "audio/mpeg")}
                    data = {
                        "model": get_model_name(),
                        "language": "en",
                        "batch": "false",
                    }

                    response = client.post(
                        f"{url}/v1/audio/transcriptions",
                        files=files,
                        data=data,
                        headers=headers,
                    )

                assert response.status_code == 200, f"Request failed: {response.text}"
                result = response.json()

                # Verify response structure for single file
                assert "text" in result
                assert "words" in result
                assert "filename" in result
                assert isinstance(result["text"], str)
                assert isinstance(result["words"], list)

                # Verify content
                assert len(result["text"]) > 0, "Transcript text should not be empty"

            finally:
                Path(tmp_file_path).unlink(missing_ok=True)

    def test_transcriptions_multiple_files(self):
        """Test the /v1/audio/transcriptions endpoint with multiple files (non-batch mode)."""
        url = get_modal_transcript_url()
        headers = get_auth_headers()

        # Create multiple test files (we'll use the same audio content for simplicity)
        with httpx.Client(timeout=60.0) as client:
            audio_response = client.get(TEST_AUDIO_URL)
            audio_response.raise_for_status()
            audio_content = audio_response.content

            temp_files = []
            try:
                # Create 3 temporary files
                for i in range(3):
                    tmp_file = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
                    tmp_file.write(audio_content)
                    tmp_file.close()
                    temp_files.append(tmp_file.name)

                # Upload multiple files for transcription (non-batch)
                files = [
                    ("files", (f"test_audio_{i}.mp3", open(f, "rb"), "audio/mpeg"))
                    for i, f in enumerate(temp_files)
                ]
                data = {
                    "model": get_model_name(),
                    "language": "en",
                    "batch": "false",
                }

                response = client.post(
                    f"{url}/v1/audio/transcriptions",
                    files=files,
                    data=data,
                    headers=headers,
                )

                # Close file handles
                for _, file_tuple in files:
                    file_tuple[1].close()

                assert response.status_code == 200, f"Request failed: {response.text}"
                result = response.json()

                # Verify response structure for multiple files (non-batch)
                assert "results" in result
                assert isinstance(result["results"], list)
                assert len(result["results"]) == 3

                for idx, file_result in enumerate(result["results"]):
                    assert "text" in file_result
                    assert "words" in file_result
                    assert "filename" in file_result
                    assert isinstance(file_result["text"], str)
                    assert isinstance(file_result["words"], list)
                    assert len(file_result["text"]) > 0

            finally:
                for f in temp_files:
                    Path(f).unlink(missing_ok=True)

    def test_transcriptions_multiple_files_batch(self):
        """Test the /v1/audio/transcriptions endpoint with multiple files in batch mode."""
        url = get_modal_transcript_url()
        headers = get_auth_headers()

        # Create multiple test files
        with httpx.Client(timeout=60.0) as client:
            audio_response = client.get(TEST_AUDIO_URL)
            audio_response.raise_for_status()
            audio_content = audio_response.content

            temp_files = []
            try:
                # Create 3 temporary files
                for i in range(3):
                    tmp_file = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
                    tmp_file.write(audio_content)
                    tmp_file.close()
                    temp_files.append(tmp_file.name)

                # Upload multiple files for batch transcription
                files = [
                    ("files", (f"test_audio_{i}.mp3", open(f, "rb"), "audio/mpeg"))
                    for i, f in enumerate(temp_files)
                ]
                data = {
                    "model": get_model_name(),
                    "language": "en",
                    "batch": "true",
                }

                response = client.post(
                    f"{url}/v1/audio/transcriptions",
                    files=files,
                    data=data,
                    headers=headers,
                )

                # Close file handles
                for _, file_tuple in files:
                    file_tuple[1].close()

                assert response.status_code == 200, f"Request failed: {response.text}"
                result = response.json()

                # Verify response structure for batch mode
                assert "results" in result
                assert isinstance(result["results"], list)
                assert len(result["results"]) == 3

                for idx, batch_result in enumerate(result["results"]):
                    assert "text" in batch_result
                    assert "words" in batch_result
                    assert "filename" in batch_result
                    assert isinstance(batch_result["text"], str)
                    assert isinstance(batch_result["words"], list)
                    assert len(batch_result["text"]) > 0

            finally:
                for f in temp_files:
                    Path(f).unlink(missing_ok=True)

    def test_transcriptions_error_handling(self):
        """Test error handling for invalid requests."""
        url = get_modal_transcript_url()
        headers = get_auth_headers()

        with httpx.Client(timeout=60.0) as client:
            # Test with unsupported language
            response = client.post(
                f"{url}/v1/audio/transcriptions-from-url",
                json={
                    "audio_file_url": TEST_AUDIO_URL,
                    "model": get_model_name(),
                    "language": "fr",  # Parakeet only supports English
                    "timestamp_offset": 0.0,
                },
                headers=headers,
            )

            assert response.status_code == 400
            assert "only supports English" in response.text

    def test_transcriptions_with_timestamp_offset(self):
        """Test transcription with timestamp offset parameter."""
        url = get_modal_transcript_url()
        headers = get_auth_headers()

        with httpx.Client(timeout=60.0) as client:
            # Test with timestamp offset
            response = client.post(
                f"{url}/v1/audio/transcriptions-from-url",
                json={
                    "audio_file_url": TEST_AUDIO_URL,
                    "model": get_model_name(),
                    "language": "en",
                    "timestamp_offset": 10.0,  # Add 10 second offset
                },
                headers=headers,
            )

            assert response.status_code == 200, f"Request failed: {response.text}"
            result = response.json()

            # Verify response structure
            assert "text" in result
            assert "words" in result
            assert len(result["words"]) > 0, "Words list must not be empty"

            # Verify that timestamps have been offset
            for word in result["words"]:
                # All timestamps should be >= 10.0 due to offset
                assert (
                    word["start"] >= 10.0
                ), f"Word start time {word['start']} should be >= 10.0"
                assert (
                    word["end"] >= 10.0
                ), f"Word end time {word['end']} should be >= 10.0"
230
server/tests/test_ics_background_tasks.py
Normal file
230
server/tests/test_ics_background_tasks.py
Normal file
@@ -0,0 +1,230 @@
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from icalendar import Calendar, Event

from reflector.db.calendar_events import calendar_events_controller
from reflector.db.rooms import rooms_controller
from reflector.worker.ics_sync import (
    _should_sync,
    _sync_all_ics_calendars_async,
    _sync_room_ics_async,
)


@pytest.mark.asyncio
async def test_sync_room_ics_task():
    room = await rooms_controller.add(
        name="task-test-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
        ics_url="https://calendar.example.com/task.ics",
        ics_enabled=True,
    )

    cal = Calendar()
    event = Event()
    event.add("uid", "task-event-1")
    event.add("summary", "Task Test Meeting")
    from reflector.settings import settings

    event.add("location", f"{settings.BASE_URL}/room/{room.name}")
    now = datetime.now(timezone.utc)
    event.add("dtstart", now + timedelta(hours=1))
    event.add("dtend", now + timedelta(hours=2))
    cal.add_component(event)
    ics_content = cal.to_ical().decode("utf-8")

    with patch(
        "reflector.services.ics_sync.ICSFetchService.fetch_ics", new_callable=AsyncMock
    ) as mock_fetch:
        mock_fetch.return_value = ics_content

        await _sync_room_ics_async(room.id)

    events = await calendar_events_controller.get_by_room(room.id)
    assert len(events) == 1
    assert events[0].ics_uid == "task-event-1"


@pytest.mark.asyncio
async def test_sync_room_ics_disabled():
    room = await rooms_controller.add(
        name="disabled-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
        ics_enabled=False,
    )

    await _sync_room_ics_async(room.id)

    events = await calendar_events_controller.get_by_room(room.id)
    assert len(events) == 0


@pytest.mark.asyncio
async def test_sync_all_ics_calendars():
    room1 = await rooms_controller.add(
        name="sync-all-1",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
        ics_url="https://calendar.example.com/1.ics",
        ics_enabled=True,
    )

    room2 = await rooms_controller.add(
        name="sync-all-2",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
        ics_url="https://calendar.example.com/2.ics",
        ics_enabled=True,
    )

    room3 = await rooms_controller.add(
        name="sync-all-3",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
        ics_enabled=False,
    )

    with patch("reflector.worker.ics_sync.sync_room_ics.delay") as mock_delay:
        await _sync_all_ics_calendars_async()

        assert mock_delay.call_count == 2
        called_room_ids = [call.args[0] for call in mock_delay.call_args_list]
        assert room1.id in called_room_ids
        assert room2.id in called_room_ids
        assert room3.id not in called_room_ids


@pytest.mark.asyncio
async def test_should_sync_logic():
    room = MagicMock()

    room.ics_last_sync = None
    assert _should_sync(room) is True

    room.ics_last_sync = datetime.now(timezone.utc) - timedelta(seconds=100)
    room.ics_fetch_interval = 300
    assert _should_sync(room) is False

    room.ics_last_sync = datetime.now(timezone.utc) - timedelta(seconds=400)
    room.ics_fetch_interval = 300
    assert _should_sync(room) is True


@pytest.mark.asyncio
async def test_sync_respects_fetch_interval():
    now = datetime.now(timezone.utc)

    room1 = await rooms_controller.add(
        name="interval-test-1",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
        ics_url="https://calendar.example.com/interval.ics",
        ics_enabled=True,
        ics_fetch_interval=300,
    )

    await rooms_controller.update(
        room1,
        {"ics_last_sync": now - timedelta(seconds=100)},
    )

    room2 = await rooms_controller.add(
        name="interval-test-2",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
        ics_url="https://calendar.example.com/interval2.ics",
        ics_enabled=True,
        ics_fetch_interval=60,
    )

    await rooms_controller.update(
        room2,
        {"ics_last_sync": now - timedelta(seconds=100)},
    )

    with patch("reflector.worker.ics_sync.sync_room_ics.delay") as mock_delay:
        await _sync_all_ics_calendars_async()

        assert mock_delay.call_count == 1
        assert mock_delay.call_args[0][0] == room2.id


@pytest.mark.asyncio
async def test_sync_handles_errors_gracefully():
    room = await rooms_controller.add(
        name="error-task-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
        ics_url="https://calendar.example.com/error.ics",
        ics_enabled=True,
    )

    with patch(
        "reflector.services.ics_sync.ICSFetchService.fetch_ics", new_callable=AsyncMock
    ) as mock_fetch:
        mock_fetch.side_effect = Exception("Network error")

        await _sync_room_ics_async(room.id)

    events = await calendar_events_controller.get_by_room(room.id)
    assert len(events) == 0
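test_should_sync_logic fixes the scheduling rule precisely: a room syncs when it has never synced, or when more than ics_fetch_interval seconds have passed since ics_last_sync. Under those assumptions, _should_sync is essentially the following (a reconstruction from the test expectations; the real function lives in reflector.worker.ics_sync and is not shown in this diff):

from datetime import datetime, timedelta, timezone

def _should_sync(room) -> bool:
    # Never synced before: always due
    if room.ics_last_sync is None:
        return True
    # Due once the per-room fetch interval has fully elapsed
    elapsed = datetime.now(timezone.utc) - room.ics_last_sync
    return elapsed > timedelta(seconds=room.ics_fetch_interval)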
289
server/tests/test_ics_sync.py
Normal file
289
server/tests/test_ics_sync.py
Normal file
@@ -0,0 +1,289 @@
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from icalendar import Calendar, Event

from reflector.db.calendar_events import calendar_events_controller
from reflector.db.rooms import rooms_controller
from reflector.services.ics_sync import ICSFetchService, ICSSyncService


@pytest.mark.asyncio
async def test_ics_fetch_service_event_matching():
    service = ICSFetchService()
    room_name = "test-room"
    room_url = "https://example.com/room/test-room"

    # Create test event
    event = Event()
    event.add("uid", "test-123")
    event.add("summary", "Test Meeting")

    # Test matching with full URL in location
    event.add("location", "https://example.com/room/test-room")
    assert service._event_matches_room(event, room_name, room_url) is True

    # Test matching with URL without protocol
    event["location"] = "example.com/room/test-room"
    assert service._event_matches_room(event, room_name, room_url) is True

    # Test matching in description
    event["location"] = "Conference Room A"
    event.add("description", f"Join at {room_url}")
    assert service._event_matches_room(event, room_name, room_url) is True

    # Test non-matching
    event["location"] = "Different Room"
    event["description"] = "No room URL here"
    assert service._event_matches_room(event, room_name, room_url) is False

    # Test partial paths should NOT match anymore
    event["location"] = "/room/test-room"
    assert service._event_matches_room(event, room_name, room_url) is False

    event["location"] = f"Room: {room_name}"
    assert service._event_matches_room(event, room_name, room_url) is False


@pytest.mark.asyncio
async def test_ics_fetch_service_parse_event():
    service = ICSFetchService()

    # Create test event
    event = Event()
    event.add("uid", "test-456")
    event.add("summary", "Team Standup")
    event.add("description", "Daily team sync")
    event.add("location", "https://example.com/room/standup")

    now = datetime.now(timezone.utc)
    event.add("dtstart", now)
    event.add("dtend", now + timedelta(hours=1))

    # Add attendees
    event.add("attendee", "mailto:alice@example.com", parameters={"CN": "Alice"})
    event.add("attendee", "mailto:bob@example.com", parameters={"CN": "Bob"})
    event.add("organizer", "mailto:carol@example.com", parameters={"CN": "Carol"})

    # Parse event
    result = service._parse_event(event)

    assert result is not None
    assert result["ics_uid"] == "test-456"
    assert result["title"] == "Team Standup"
    assert result["description"] == "Daily team sync"
    assert result["location"] == "https://example.com/room/standup"
    assert len(result["attendees"]) == 3  # 2 attendees + 1 organizer


@pytest.mark.asyncio
async def test_ics_fetch_service_extract_room_events():
    service = ICSFetchService()
    room_name = "meeting"
    room_url = "https://example.com/room/meeting"

    # Create calendar with multiple events
    cal = Calendar()

    # Event 1: Matches room
    event1 = Event()
    event1.add("uid", "match-1")
    event1.add("summary", "Planning Meeting")
    event1.add("location", room_url)
    now = datetime.now(timezone.utc)
    event1.add("dtstart", now + timedelta(hours=2))
    event1.add("dtend", now + timedelta(hours=3))
    cal.add_component(event1)

    # Event 2: Doesn't match room
    event2 = Event()
    event2.add("uid", "no-match")
    event2.add("summary", "Other Meeting")
    event2.add("location", "https://example.com/room/other")
    event2.add("dtstart", now + timedelta(hours=4))
    event2.add("dtend", now + timedelta(hours=5))
    cal.add_component(event2)

    # Event 3: Matches room in description
    event3 = Event()
    event3.add("uid", "match-2")
    event3.add("summary", "Review Session")
    event3.add("description", f"Meeting link: {room_url}")
    event3.add("dtstart", now + timedelta(hours=6))
    event3.add("dtend", now + timedelta(hours=7))
    cal.add_component(event3)

    # Event 4: Cancelled event (should be skipped)
    event4 = Event()
    event4.add("uid", "cancelled")
    event4.add("summary", "Cancelled Meeting")
    event4.add("location", room_url)
    event4.add("status", "CANCELLED")
    event4.add("dtstart", now + timedelta(hours=8))
    event4.add("dtend", now + timedelta(hours=9))
    cal.add_component(event4)

    # Extract events
    events = service.extract_room_events(cal, room_name, room_url)

    assert len(events) == 2
    assert events[0]["ics_uid"] == "match-1"
    assert events[1]["ics_uid"] == "match-2"


@pytest.mark.asyncio
async def test_ics_sync_service_sync_room_calendar():
    # Create room
    room = await rooms_controller.add(
        name="sync-test",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
        ics_url="https://calendar.example.com/test.ics",
        ics_enabled=True,
    )

    # Mock ICS content
    cal = Calendar()
    event = Event()
    event.add("uid", "sync-event-1")
    event.add("summary", "Sync Test Meeting")
    # Use the actual BASE_URL from settings
    from reflector.settings import settings

    event.add("location", f"{settings.BASE_URL}/room/{room.name}")
    now = datetime.now(timezone.utc)
    event.add("dtstart", now + timedelta(hours=1))
    event.add("dtend", now + timedelta(hours=2))
    cal.add_component(event)
    ics_content = cal.to_ical().decode("utf-8")

    # Create sync service and mock fetch
    sync_service = ICSSyncService()

    with patch.object(
        sync_service.fetch_service, "fetch_ics", new_callable=AsyncMock
    ) as mock_fetch:
        mock_fetch.return_value = ics_content

        # First sync
        result = await sync_service.sync_room_calendar(room)

        assert result["status"] == "success"
        assert result["events_found"] == 1
        assert result["events_created"] == 1
        assert result["events_updated"] == 0
        assert result["events_deleted"] == 0

        # Verify event was created
        events = await calendar_events_controller.get_by_room(room.id)
        assert len(events) == 1
        assert events[0].ics_uid == "sync-event-1"
        assert events[0].title == "Sync Test Meeting"

        # Second sync with same content (should be unchanged)
        # Refresh room to get updated etag and force sync by setting old sync time
        room = await rooms_controller.get_by_id(room.id)
        await rooms_controller.update(
            room, {"ics_last_sync": datetime.now(timezone.utc) - timedelta(minutes=10)}
        )
        result = await sync_service.sync_room_calendar(room)
        assert result["status"] == "unchanged"

        # Third sync with updated event
        event["summary"] = "Updated Meeting Title"
        cal = Calendar()
        cal.add_component(event)
        ics_content = cal.to_ical().decode("utf-8")
        mock_fetch.return_value = ics_content

        # Force sync by clearing etag
        await rooms_controller.update(room, {"ics_last_etag": None})

        result = await sync_service.sync_room_calendar(room)
        assert result["status"] == "success"
        assert result["events_created"] == 0
        assert result["events_updated"] == 1

        # Verify event was updated
        events = await calendar_events_controller.get_by_room(room.id)
        assert len(events) == 1
        assert events[0].title == "Updated Meeting Title"


@pytest.mark.asyncio
async def test_ics_sync_service_should_sync():
    service = ICSSyncService()

    # Room never synced
    room = MagicMock()
    room.ics_last_sync = None
    room.ics_fetch_interval = 300
    assert service._should_sync(room) is True

    # Room synced recently
    room.ics_last_sync = datetime.now(timezone.utc) - timedelta(seconds=100)
    assert service._should_sync(room) is False

    # Room sync due
    room.ics_last_sync = datetime.now(timezone.utc) - timedelta(seconds=400)
    assert service._should_sync(room) is True


@pytest.mark.asyncio
async def test_ics_sync_service_skip_disabled():
    service = ICSSyncService()

    # Room with ICS disabled
    room = MagicMock()
    room.ics_enabled = False
    room.ics_url = "https://calendar.example.com/test.ics"

    result = await service.sync_room_calendar(room)
    assert result["status"] == "skipped"
    assert result["reason"] == "ICS not configured"

    # Room without URL
    room.ics_enabled = True
    room.ics_url = None

    result = await service.sync_room_calendar(room)
    assert result["status"] == "skipped"
    assert result["reason"] == "ICS not configured"


@pytest.mark.asyncio
async def test_ics_sync_service_error_handling():
    # Create room
    room = await rooms_controller.add(
        name="error-test",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
        ics_url="https://calendar.example.com/error.ics",
        ics_enabled=True,
    )

    sync_service = ICSSyncService()

    with patch.object(
        sync_service.fetch_service, "fetch_ics", new_callable=AsyncMock
    ) as mock_fetch:
        mock_fetch.side_effect = Exception("Network error")

        result = await sync_service.sync_room_calendar(room)
        assert result["status"] == "error"
        assert "Network error" in result["error"]
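The matching tests at the top of this file define the contract tightly: an event belongs to a room only if the full room URL appears in its location or description, with the scheme optional; bare paths like /room/test-room or plain room names no longer match. A minimal function satisfying exactly those assertions (the real ICSFetchService method is not in this diff, so treat this as a sketch):

def event_matches_room(location: str, description: str, room_url: str) -> bool:
    # Accept the full URL with or without the http(s):// scheme, nothing looser
    bare = room_url.removeprefix("https://").removeprefix("http://")
    haystack = f"{location} {description}"
    return room_url in haystack or bare in haystack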
283
server/tests/test_multiple_active_meetings.py
Normal file
283
server/tests/test_multiple_active_meetings.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""Tests for multiple active meetings per room functionality."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.db.calendar_events import CalendarEvent, calendar_events_controller
|
||||
from reflector.db.meetings import meetings_controller
|
||||
from reflector.db.rooms import rooms_controller
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_active_meetings_per_room():
|
||||
"""Test that multiple active meetings can exist for the same room."""
|
||||
# Create a room
|
||||
room = await rooms_controller.add(
|
||||
name="test-room",
|
||||
user_id="test-user",
|
||||
zulip_auto_post=False,
|
||||
zulip_stream="",
|
||||
zulip_topic="",
|
||||
is_locked=False,
|
||||
room_mode="normal",
|
||||
recording_type="cloud",
|
||||
recording_trigger="automatic-2nd-participant",
|
||||
is_shared=False,
|
||||
)
|
||||
|
||||
current_time = datetime.now(timezone.utc)
|
||||
end_time = current_time + timedelta(hours=2)
|
||||
|
||||
# Create first meeting
|
||||
meeting1 = await meetings_controller.create(
|
||||
id="meeting-1",
|
||||
room_name="test-meeting-1",
|
||||
room_url="https://whereby.com/test-1",
|
||||
host_room_url="https://whereby.com/test-1-host",
|
||||
start_date=current_time,
|
||||
end_date=end_time,
|
||||
user_id="test-user",
|
||||
room=room,
|
||||
)
|
||||
|
||||
# Create second meeting for the same room (should succeed now)
|
||||
meeting2 = await meetings_controller.create(
|
||||
id="meeting-2",
|
||||
room_name="test-meeting-2",
|
||||
room_url="https://whereby.com/test-2",
|
||||
host_room_url="https://whereby.com/test-2-host",
|
||||
start_date=current_time,
|
||||
        end_date=end_time,
        user_id="test-user",
        room=room,
    )

    # Both meetings should be active
    active_meetings = await meetings_controller.get_all_active_for_room(
        room=room, current_time=current_time
    )

    assert len(active_meetings) == 2
    assert meeting1.id in [m.id for m in active_meetings]
    assert meeting2.id in [m.id for m in active_meetings]


@pytest.mark.asyncio
async def test_get_active_by_calendar_event():
    """Test getting active meeting by calendar event ID."""
    # Create a room
    room = await rooms_controller.add(
        name="test-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
    )

    # Create a calendar event
    event = CalendarEvent(
        room_id=room.id,
        ics_uid="test-event-uid",
        title="Test Meeting",
        start_time=datetime.now(timezone.utc),
        end_time=datetime.now(timezone.utc) + timedelta(hours=1),
    )
    event = await calendar_events_controller.upsert(event)

    current_time = datetime.now(timezone.utc)
    end_time = current_time + timedelta(hours=2)

    # Create meeting linked to calendar event
    meeting = await meetings_controller.create(
        id="meeting-cal-1",
        room_name="test-meeting-cal",
        room_url="https://whereby.com/test-cal",
        host_room_url="https://whereby.com/test-cal-host",
        start_date=current_time,
        end_date=end_time,
        user_id="test-user",
        room=room,
        calendar_event_id=event.id,
        calendar_metadata={"title": event.title},
    )

    # Should find the meeting by calendar event
    found_meeting = await meetings_controller.get_active_by_calendar_event(
        room=room, calendar_event_id=event.id, current_time=current_time
    )

    assert found_meeting is not None
    assert found_meeting.id == meeting.id
    assert found_meeting.calendar_event_id == event.id


@pytest.mark.asyncio
async def test_grace_period_logic():
    """Test that meetings have a grace period after the last participant leaves."""
    # Create a room
    room = await rooms_controller.add(
        name="test-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
    )

    current_time = datetime.now(timezone.utc)
    end_time = current_time + timedelta(hours=2)

    # Create meeting
    meeting = await meetings_controller.create(
        id="meeting-grace",
        room_name="test-meeting-grace",
        room_url="https://whereby.com/test-grace",
        host_room_url="https://whereby.com/test-grace-host",
        start_date=current_time,
        end_date=end_time,
        user_id="test-user",
        room=room,
    )

    # Test grace period logic by simulating different states

    # Simulate first time all participants left
    await meetings_controller.update_meeting(
        meeting.id, num_clients=0, last_participant_left_at=current_time
    )

    # Within grace period (10 min) - should still be active
    await meetings_controller.update_meeting(
        meeting.id, last_participant_left_at=current_time - timedelta(minutes=10)
    )

    updated_meeting = await meetings_controller.get_by_id(meeting.id)
    assert updated_meeting.is_active is True  # Still active during grace period

    # Simulate grace period expired (20 min) and deactivate
    await meetings_controller.update_meeting(
        meeting.id, last_participant_left_at=current_time - timedelta(minutes=20)
    )

    # Manually test the grace period logic that would be in process_meetings
    updated_meeting = await meetings_controller.get_by_id(meeting.id)
    if updated_meeting.last_participant_left_at:
        grace_period = timedelta(minutes=updated_meeting.grace_period_minutes)
        if current_time > updated_meeting.last_participant_left_at + grace_period:
            await meetings_controller.update_meeting(meeting.id, is_active=False)

    updated_meeting = await meetings_controller.get_by_id(meeting.id)
    assert updated_meeting.is_active is False  # Now deactivated
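

# A minimal sketch of the grace-period rule exercised above, written as a
# standalone helper. Illustrative only: the real check lives in
# process_meetings (not shown in this file), and the helper name is hypothetical.
def _grace_period_expired(meeting, now):
    """Return True once the grace period after the last participant left has elapsed."""
    if meeting.last_participant_left_at is None:
        return False
    grace_period = timedelta(minutes=meeting.grace_period_minutes)
    return now > meeting.last_participant_left_at + grace_period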


@pytest.mark.asyncio
async def test_calendar_meeting_force_close_after_30_min():
    """Test that calendar meetings force close 30 minutes after scheduled end."""
    # Create a room
    room = await rooms_controller.add(
        name="test-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
    )

    # Create a calendar event
    event = CalendarEvent(
        room_id=room.id,
        ics_uid="test-event-force",
        title="Test Meeting Force Close",
        start_time=datetime.now(timezone.utc) - timedelta(hours=2),
        end_time=datetime.now(timezone.utc) - timedelta(minutes=35),  # Ended 35 min ago
    )
    event = await calendar_events_controller.upsert(event)

    current_time = datetime.now(timezone.utc)

    # Create meeting linked to calendar event
    meeting = await meetings_controller.create(
        id="meeting-force",
        room_name="test-meeting-force",
        room_url="https://whereby.com/test-force",
        host_room_url="https://whereby.com/test-force-host",
        start_date=event.start_time,
        end_date=event.end_time,
        user_id="test-user",
        room=room,
        calendar_event_id=event.id,
    )

    # Test that calendar meetings force close 30 min after scheduled end
    # The meeting ended 35 minutes ago, so it should be force closed

    # Manually test the force close logic that would be in process_meetings
    if meeting.calendar_event_id:
        if current_time > meeting.end_date + timedelta(minutes=30):
            await meetings_controller.update_meeting(meeting.id, is_active=False)

    updated_meeting = await meetings_controller.get_by_id(meeting.id)
    assert updated_meeting.is_active is False  # Force closed after 30 min
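

# For contrast with the grace-period rule, a hedged sketch of the 30-minute
# hard cutoff for calendar-backed meetings (helper name is hypothetical; the
# real check lives in process_meetings, which is not shown here):
def _calendar_force_close_due(meeting, now):
    """Return True when a calendar-linked meeting is 30+ minutes past its scheduled end."""
    return bool(meeting.calendar_event_id) and now > meeting.end_date + timedelta(
        minutes=30
    )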


@pytest.mark.asyncio
async def test_participant_rejoin_clears_grace_period():
    """Test that participant rejoining clears the grace period."""
    # Create a room
    room = await rooms_controller.add(
        name="test-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
    )

    current_time = datetime.now(timezone.utc)
    end_time = current_time + timedelta(hours=2)

    # Create meeting with grace period already set
    meeting = await meetings_controller.create(
        id="meeting-rejoin",
        room_name="test-meeting-rejoin",
        room_url="https://whereby.com/test-rejoin",
        host_room_url="https://whereby.com/test-rejoin-host",
        start_date=current_time,
        end_date=end_time,
        user_id="test-user",
        room=room,
    )

    # Set last_participant_left_at to simulate grace period
    await meetings_controller.update_meeting(
        meeting.id,
        last_participant_left_at=current_time - timedelta(minutes=5),
        num_clients=0,
    )

    # Simulate participant rejoining - clear grace period
    await meetings_controller.update_meeting(
        meeting.id, last_participant_left_at=None, num_clients=1
    )

    updated_meeting = await meetings_controller.get_by_id(meeting.id)
    assert updated_meeting.last_participant_left_at is None  # Grace period cleared
    assert updated_meeting.is_active is True  # Still active
@@ -1,633 +0,0 @@
"""
Tests for PipelineMainFile - file-based processing pipeline

These tests verify the complete file processing pipeline with minimal mocking,
ensuring all processors are invoked and the happy path completes correctly.
"""

from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4

import pytest

from reflector.pipelines.main_file_pipeline import PipelineMainFile
from reflector.processors.file_diarization import FileDiarizationOutput
from reflector.processors.types import (
    DiarizationSegment,
    TitleSummary,
    Word,
)
from reflector.processors.types import (
    Transcript as TranscriptType,
)


@pytest.fixture
async def dummy_file_transcript():
    """Mock FileTranscriptAutoProcessor for file processing"""
    from reflector.processors.file_transcript import FileTranscriptProcessor

    class TestFileTranscriptProcessor(FileTranscriptProcessor):
        async def _transcript(self, data):
            return TranscriptType(
                text="Hello world. How are you today?",
                words=[
                    Word(start=0.0, end=0.5, text="Hello", speaker=0),
                    Word(start=0.5, end=0.6, text=" ", speaker=0),
                    Word(start=0.6, end=1.0, text="world", speaker=0),
                    Word(start=1.0, end=1.1, text=".", speaker=0),
                    Word(start=1.1, end=1.2, text=" ", speaker=0),
                    Word(start=1.2, end=1.5, text="How", speaker=0),
                    Word(start=1.5, end=1.6, text=" ", speaker=0),
                    Word(start=1.6, end=1.8, text="are", speaker=0),
                    Word(start=1.8, end=1.9, text=" ", speaker=0),
                    Word(start=1.9, end=2.1, text="you", speaker=0),
                    Word(start=2.1, end=2.2, text=" ", speaker=0),
                    Word(start=2.2, end=2.5, text="today", speaker=0),
                    Word(start=2.5, end=2.6, text="?", speaker=0),
                ],
            )

    with patch(
        "reflector.processors.file_transcript_auto.FileTranscriptAutoProcessor.__new__"
    ) as mock_auto:
        mock_auto.return_value = TestFileTranscriptProcessor()
        yield


@pytest.fixture
async def dummy_file_diarization():
    """Mock FileDiarizationAutoProcessor for file processing"""
    from reflector.processors.file_diarization import FileDiarizationProcessor

    class TestFileDiarizationProcessor(FileDiarizationProcessor):
        async def _diarize(self, data):
            return FileDiarizationOutput(
                diarization=[
                    DiarizationSegment(start=0.0, end=1.1, speaker=0),
                    DiarizationSegment(start=1.2, end=2.6, speaker=1),
                ]
            )

    with patch(
        "reflector.processors.file_diarization_auto.FileDiarizationAutoProcessor.__new__"
    ) as mock_auto:
        mock_auto.return_value = TestFileDiarizationProcessor()
        yield


@pytest.fixture
async def mock_transcript_in_db(tmpdir):
    """Create a mock transcript in the database"""
    from reflector.db.transcripts import Transcript
    from reflector.settings import settings

    # Set the DATA_DIR to our tmpdir
    original_data_dir = settings.DATA_DIR
    settings.DATA_DIR = str(tmpdir)

    transcript_id = str(uuid4())
    data_path = Path(tmpdir) / transcript_id
    data_path.mkdir(parents=True, exist_ok=True)

    # Create mock transcript object
    transcript = Transcript(
        id=transcript_id,
        name="Test Transcript",
        status="processing",
        source_kind="file",
        source_language="en",
        target_language="en",
    )

    # Mock the controller to return our transcript
    try:
        with patch(
            "reflector.pipelines.main_file_pipeline.transcripts_controller.get_by_id"
        ) as mock_get:
            mock_get.return_value = transcript
            with patch(
                "reflector.pipelines.main_live_pipeline.transcripts_controller.get_by_id"
            ) as mock_get2:
                mock_get2.return_value = transcript
                with patch(
                    "reflector.pipelines.main_live_pipeline.transcripts_controller.update"
                ) as mock_update:
                    mock_update.return_value = None
                    yield transcript
    finally:
        # Restore original DATA_DIR
        settings.DATA_DIR = original_data_dir
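
# Both the file and live pipeline modules are patched above because `patch`
# must target the name where it is looked up, and each module resolves
# transcripts_controller through its own import.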


@pytest.fixture
async def mock_storage():
    """Mock storage for file uploads"""
    from reflector.storage.base import Storage

    class TestStorage(Storage):
        async def _put_file(self, path, data):
            return None

        async def _get_file_url(self, path):
            return f"http://test-storage/{path}"

        async def _get_file(self, path):
            return b"test_audio_data"

        async def _delete_file(self, path):
            return None

    storage = TestStorage()
    # Add mock tracking for verification
    storage._put_file = AsyncMock(side_effect=storage._put_file)
    storage._get_file_url = AsyncMock(side_effect=storage._get_file_url)

    with patch(
        "reflector.pipelines.main_file_pipeline.get_transcripts_storage"
    ) as mock_get:
        mock_get.return_value = storage
        yield storage


@pytest.fixture
async def mock_audio_file_writer():
    """Mock AudioFileWriterProcessor to avoid actual file writing"""
    with patch(
        "reflector.pipelines.main_file_pipeline.AudioFileWriterProcessor"
    ) as mock_writer_class:
        mock_writer = AsyncMock()
        mock_writer.push = AsyncMock()
        mock_writer.flush = AsyncMock()
        mock_writer_class.return_value = mock_writer
        yield mock_writer


@pytest.fixture
async def mock_waveform_processor():
    """Mock AudioWaveformProcessor"""
    with patch(
        "reflector.pipelines.main_file_pipeline.AudioWaveformProcessor"
    ) as mock_waveform_class:
        mock_waveform = AsyncMock()
        mock_waveform.set_pipeline = MagicMock()
        mock_waveform.flush = AsyncMock()
        mock_waveform_class.return_value = mock_waveform
        yield mock_waveform


@pytest.fixture
async def mock_topic_detector():
    """Mock TranscriptTopicDetectorProcessor"""
    with patch(
        "reflector.pipelines.main_file_pipeline.TranscriptTopicDetectorProcessor"
    ) as mock_topic_class:
        mock_topic = AsyncMock()
        mock_topic.set_pipeline = MagicMock()
        mock_topic.push = AsyncMock()
        mock_topic.flush_called = False

        # When flush is called, simulate topic detection by calling the callback
        async def flush_with_callback():
            mock_topic.flush_called = True
            if hasattr(mock_topic, "_callback"):
                # Create a minimal transcript for the TitleSummary
                test_transcript = TranscriptType(words=[], text="test transcript")
                await mock_topic._callback(
                    TitleSummary(
                        title="Test Topic",
                        summary="Test topic summary",
                        timestamp=0.0,
                        duration=10.0,
                        transcript=test_transcript,
                    )
                )

        mock_topic.flush = flush_with_callback

        def init_with_callback(callback=None):
            mock_topic._callback = callback
            return mock_topic

        mock_topic_class.side_effect = init_with_callback
        yield mock_topic


@pytest.fixture
async def mock_title_processor():
    """Mock TranscriptFinalTitleProcessor"""
    with patch(
        "reflector.pipelines.main_file_pipeline.TranscriptFinalTitleProcessor"
    ) as mock_title_class:
        mock_title = AsyncMock()
        mock_title.set_pipeline = MagicMock()
        mock_title.push = AsyncMock()
        mock_title.flush_called = False

        # When flush is called, simulate title generation by calling the callback
        async def flush_with_callback():
            mock_title.flush_called = True
            if hasattr(mock_title, "_callback"):
                from reflector.processors.types import FinalTitle

                await mock_title._callback(FinalTitle(title="Test Title"))

        mock_title.flush = flush_with_callback

        def init_with_callback(callback=None):
            mock_title._callback = callback
            return mock_title

        mock_title_class.side_effect = init_with_callback
        yield mock_title


@pytest.fixture
async def mock_summary_processor():
    """Mock TranscriptFinalSummaryProcessor"""
    with patch(
        "reflector.pipelines.main_file_pipeline.TranscriptFinalSummaryProcessor"
    ) as mock_summary_class:
        mock_summary = AsyncMock()
        mock_summary.set_pipeline = MagicMock()
        mock_summary.push = AsyncMock()
        mock_summary.flush_called = False

        # When flush is called, simulate summary generation by calling the callbacks
        async def flush_with_callback():
            mock_summary.flush_called = True
            from reflector.processors.types import FinalLongSummary, FinalShortSummary

            if hasattr(mock_summary, "_callback"):
                await mock_summary._callback(
                    FinalLongSummary(long_summary="Test long summary", duration=10.0)
                )
            if hasattr(mock_summary, "_on_short_summary"):
                await mock_summary._on_short_summary(
                    FinalShortSummary(short_summary="Test short summary", duration=10.0)
                )

        mock_summary.flush = flush_with_callback

        def init_with_callback(transcript=None, callback=None, on_short_summary=None):
            mock_summary._callback = callback
            mock_summary._on_short_summary = on_short_summary
            return mock_summary

        mock_summary_class.side_effect = init_with_callback
        yield mock_summary


@pytest.mark.asyncio
async def test_pipeline_main_file_process(
    tmpdir,
    mock_transcript_in_db,
    dummy_file_transcript,
    dummy_file_diarization,
    mock_storage,
    mock_audio_file_writer,
    mock_waveform_processor,
    mock_topic_detector,
    mock_title_processor,
    mock_summary_processor,
):
    """
    Test the complete PipelineMainFile processing pipeline.

    This test verifies:
    1. Audio extraction and writing
    2. Audio upload to storage
    3. Parallel processing of transcription, diarization, and waveform
    4. Assembly of transcript with diarization
    5. Topic detection
    6. Title and summary generation
    """
    # Create a test audio file
    test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"

    # Copy test audio to the transcript's data path as if it was uploaded
    upload_path = mock_transcript_in_db.data_path / "upload.wav"
    upload_path.write_bytes(test_audio_path.read_bytes())

    # Also create the audio.mp3 file that would be created by AudioFileWriterProcessor
    # Since we're mocking AudioFileWriterProcessor, we need to create this manually
    mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
    mp3_path.write_bytes(b"mock_mp3_data")

    # Track callback invocations
    callback_marks = {
        "on_status": [],
        "on_duration": [],
        "on_waveform": [],
        "on_topic": [],
        "on_title": [],
        "on_long_summary": [],
        "on_short_summary": [],
    }

    # Create pipeline with mocked callbacks
    pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)

    # Override callbacks to track invocations
    async def track_callback(name, data):
        callback_marks[name].append(data)
        # Call the original callback
        original = getattr(PipelineMainFile, name)
        return await original(pipeline, data)

    for callback_name in callback_marks.keys():
        setattr(
            pipeline,
            callback_name,
            lambda data, n=callback_name: track_callback(n, data),
        )
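
    # Note: the `n=callback_name` default argument in the lambda above binds the
    # loop variable eagerly; closing over `callback_name` directly would leave
    # every wrapper reporting the last callback name in the dict.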

    # Mock av.open for audio processing
    with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
        # Mock container for checking video streams
        mock_container = MagicMock()
        mock_container.streams.video = []  # No video streams (audio only)
        mock_container.close = MagicMock()

        # Mock container for decoding audio frames
        mock_decode_container = MagicMock()
        mock_decode_container.decode.return_value = iter(
            [MagicMock()]
        )  # One mock audio frame
        mock_decode_container.close = MagicMock()

        # Return different containers for different calls
        mock_av.side_effect = [mock_container, mock_decode_container]

        # Run the pipeline
        await pipeline.process(upload_path)

    # Verify audio extraction and writing
    assert mock_audio_file_writer.push.called
    assert mock_audio_file_writer.flush.called

    # Verify storage upload
    assert mock_storage._put_file.called
    assert mock_storage._get_file_url.called

    # Verify waveform generation
    assert mock_waveform_processor.flush.called
    assert mock_waveform_processor.set_pipeline.called

    # Verify topic detection
    assert mock_topic_detector.push.called
    assert mock_topic_detector.flush_called

    # Verify title generation
    assert mock_title_processor.push.called
    assert mock_title_processor.flush_called

    # Verify summary generation
    assert mock_summary_processor.push.called
    assert mock_summary_processor.flush_called

    # Verify callbacks were invoked
    assert len(callback_marks["on_topic"]) > 0, "Topic callback should be invoked"
    assert len(callback_marks["on_title"]) > 0, "Title callback should be invoked"
    assert (
        len(callback_marks["on_long_summary"]) > 0
    ), "Long summary callback should be invoked"
    assert (
        len(callback_marks["on_short_summary"]) > 0
    ), "Short summary callback should be invoked"

    print(f"Callback marks: {callback_marks}")

    # Verify the pipeline completed successfully
    assert pipeline.logger is not None
    print("PipelineMainFile test completed successfully!")


@pytest.mark.asyncio
async def test_pipeline_main_file_with_video(
    tmpdir,
    mock_transcript_in_db,
    dummy_file_transcript,
    dummy_file_diarization,
    mock_storage,
    mock_audio_file_writer,
    mock_waveform_processor,
    mock_topic_detector,
    mock_title_processor,
    mock_summary_processor,
):
    """
    Test PipelineMainFile with video input (verifies audio extraction).
    """
    # Create a test audio file
    test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"

    # Copy test audio to the transcript's data path as if it was a video upload
    upload_path = mock_transcript_in_db.data_path / "upload.mp4"
    upload_path.write_bytes(test_audio_path.read_bytes())

    # Also create the audio.mp3 file that would be created by AudioFileWriterProcessor
    mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
    mp3_path.write_bytes(b"mock_mp3_data")

    # Create pipeline
    pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)

    # Mock av.open for video processing
    with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
        # Mock container for checking video streams
        mock_container = MagicMock()
        mock_container.streams.video = [MagicMock()]  # Has video streams
        mock_container.close = MagicMock()

        # Mock container for decoding audio frames
        mock_decode_container = MagicMock()
        mock_decode_container.decode.return_value = iter(
            [MagicMock()]
        )  # One mock audio frame
        mock_decode_container.close = MagicMock()

        # Return different containers for different calls
        mock_av.side_effect = [mock_container, mock_decode_container]

        # Run the pipeline
        await pipeline.process(upload_path)

    # Verify audio extraction from video
    assert mock_audio_file_writer.push.called
    assert mock_audio_file_writer.flush.called

    # Verify the rest of the pipeline completed
    assert mock_storage._put_file.called
    assert mock_waveform_processor.flush.called
    assert mock_topic_detector.push.called
    assert mock_title_processor.push.called
    assert mock_summary_processor.push.called

    print("PipelineMainFile video test completed successfully!")


@pytest.mark.asyncio
async def test_pipeline_main_file_no_diarization(
    tmpdir,
    mock_transcript_in_db,
    dummy_file_transcript,
    mock_storage,
    mock_audio_file_writer,
    mock_waveform_processor,
    mock_topic_detector,
    mock_title_processor,
    mock_summary_processor,
):
    """
    Test PipelineMainFile with diarization disabled.
    """
    from reflector.settings import settings

    # Disable diarization
    with patch.object(settings, "DIARIZATION_BACKEND", None):
        # Create a test audio file
        test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"

        # Copy test audio to the transcript's data path
        upload_path = mock_transcript_in_db.data_path / "upload.wav"
        upload_path.write_bytes(test_audio_path.read_bytes())

        # Also create the audio.mp3 file
        mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
        mp3_path.write_bytes(b"mock_mp3_data")

        # Create pipeline
        pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)

        # Mock av.open for audio processing
        with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
            # Mock container for checking video streams
            mock_container = MagicMock()
            mock_container.streams.video = []  # No video streams
            mock_container.close = MagicMock()

            # Mock container for decoding audio frames
            mock_decode_container = MagicMock()
            mock_decode_container.decode.return_value = iter([MagicMock()])
            mock_decode_container.close = MagicMock()

            # Return different containers for different calls
            mock_av.side_effect = [mock_container, mock_decode_container]

            # Run the pipeline
            await pipeline.process(upload_path)

        # Verify the pipeline completed without diarization
        assert mock_storage._put_file.called
        assert mock_waveform_processor.flush.called
        assert mock_topic_detector.push.called
        assert mock_title_processor.push.called
        assert mock_summary_processor.push.called

        print("PipelineMainFile no-diarization test completed successfully!")


@pytest.mark.asyncio
async def test_task_pipeline_file_process(
    tmpdir,
    mock_transcript_in_db,
    dummy_file_transcript,
    dummy_file_diarization,
    mock_storage,
    mock_audio_file_writer,
    mock_waveform_processor,
    mock_topic_detector,
    mock_title_processor,
    mock_summary_processor,
):
    """
    Test the Celery task entry point for file pipeline processing.
    """
    # Direct import of the underlying async function, bypassing the asynctask decorator

    # Create a test audio file in the transcript's data path
    test_audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
    upload_path = mock_transcript_in_db.data_path / "upload.wav"
    upload_path.write_bytes(test_audio_path.read_bytes())

    # Also create the audio.mp3 file
    mp3_path = mock_transcript_in_db.data_path / "audio.mp3"
    mp3_path.write_bytes(b"mock_mp3_data")

    # Mock av.open for audio processing
    with patch("reflector.pipelines.main_file_pipeline.av.open") as mock_av:
        # Mock container for checking video streams
        mock_container = MagicMock()
        mock_container.streams.video = []  # No video streams
        mock_container.close = MagicMock()

        # Mock container for decoding audio frames
        mock_decode_container = MagicMock()
        mock_decode_container.decode.return_value = iter([MagicMock()])
        mock_decode_container.close = MagicMock()

        # Return different containers for different calls
        mock_av.side_effect = [mock_container, mock_decode_container]

        # Get the original async function without the asynctask decorator
        # The function is wrapped, so we need to call it differently
        # For now, we test the pipeline directly since the task is just a thin wrapper
        from reflector.pipelines.main_file_pipeline import PipelineMainFile

        pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)
        await pipeline.process(upload_path)

    # Verify the pipeline was executed through the task
    assert mock_audio_file_writer.push.called
    assert mock_audio_file_writer.flush.called
    assert mock_storage._put_file.called
    assert mock_waveform_processor.flush.called
    assert mock_topic_detector.push.called
    assert mock_title_processor.push.called
    assert mock_summary_processor.push.called

    print("task_pipeline_file_process test completed successfully!")


@pytest.mark.asyncio
async def test_pipeline_file_process_no_transcript():
    """
    Test the pipeline with a non-existent transcript.
    """
    from reflector.pipelines.main_file_pipeline import PipelineMainFile

    # Mock the controller to return None (transcript not found)
    with patch(
        "reflector.pipelines.main_file_pipeline.transcripts_controller.get_by_id"
    ) as mock_get:
        mock_get.return_value = None

        pipeline = PipelineMainFile(transcript_id=str(uuid4()))

        # Should raise an exception for missing transcript when get_transcript is called
        with pytest.raises(Exception, match="Transcript not found"):
            await pipeline.get_transcript()


@pytest.mark.asyncio
async def test_pipeline_file_process_no_audio_file(
    mock_transcript_in_db,
):
    """
    Test the pipeline when no audio file is found.
    """
    from reflector.pipelines.main_file_pipeline import PipelineMainFile

    # Don't create any audio files in the data path
    # The pipeline's process should handle missing files gracefully

    pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)

    # Try to process a non-existent file
    non_existent_path = mock_transcript_in_db.data_path / "nonexistent.wav"

    # This should fail when trying to open the file with av
    with pytest.raises(Exception):
        await pipeline.process(non_existent_path)
@@ -1,265 +0,0 @@
"""
Tests for Modal-based processors using pytest-recording for HTTP recording/playback

Note: these tests require a full Modal configuration to be able to record
VCR cassettes

Configuration required for the first recording:
- TRANSCRIPT_BACKEND=modal
- TRANSCRIPT_URL=https://xxxxx--reflector-transcriber-parakeet-web.modal.run
- TRANSCRIPT_MODAL_API_KEY=xxxxx
- DIARIZATION_BACKEND=modal
- DIARIZATION_URL=https://xxxxx--reflector-diarizer-web.modal.run
- DIARIZATION_MODAL_API_KEY=xxxxx
"""

from unittest.mock import patch

import pytest

from reflector.processors.file_diarization import FileDiarizationInput
from reflector.processors.file_diarization_modal import FileDiarizationModalProcessor
from reflector.processors.file_transcript import FileTranscriptInput
from reflector.processors.file_transcript_modal import FileTranscriptModalProcessor
from reflector.processors.transcript_diarization_assembler import (
    TranscriptDiarizationAssemblerInput,
    TranscriptDiarizationAssemblerProcessor,
)
from reflector.processors.types import DiarizationSegment, Transcript, Word

# Public test audio file hosted on S3 specifically for reflector pytests
TEST_AUDIO_URL = (
    "https://reflector-github-pytest.s3.us-east-1.amazonaws.com/test_mathieu_hello.mp3"
)


@pytest.mark.asyncio
async def test_file_transcript_modal_processor_missing_url():
    with patch("reflector.processors.file_transcript_modal.settings") as mock_settings:
        mock_settings.TRANSCRIPT_URL = None
        with pytest.raises(Exception, match="TRANSCRIPT_URL required"):
            FileTranscriptModalProcessor(modal_api_key="test-api-key")


@pytest.mark.asyncio
async def test_file_diarization_modal_processor_missing_url():
    with patch("reflector.processors.file_diarization_modal.settings") as mock_settings:
        mock_settings.DIARIZATION_URL = None
        with pytest.raises(Exception, match="DIARIZATION_URL required"):
            FileDiarizationModalProcessor(modal_api_key="test-api-key")


@pytest.mark.vcr()
@pytest.mark.asyncio
async def test_file_diarization_modal_processor(vcr):
    """Test FileDiarizationModalProcessor using public audio URL and Modal API"""
    from reflector.settings import settings

    processor = FileDiarizationModalProcessor(
        modal_api_key=settings.DIARIZATION_MODAL_API_KEY
    )

    test_input = FileDiarizationInput(audio_url=TEST_AUDIO_URL)
    result = await processor._diarize(test_input)

    # Verify the result structure
    assert result is not None
    assert hasattr(result, "diarization")
    assert isinstance(result.diarization, list)

    # Check structure of each diarization segment
    for segment in result.diarization:
        assert "start" in segment
        assert "end" in segment
        assert "speaker" in segment
        assert isinstance(segment["start"], (int, float))
        assert isinstance(segment["end"], (int, float))
        assert isinstance(segment["speaker"], int)
        # Basic sanity check - start should be before end
        assert segment["start"] < segment["end"]


@pytest.mark.vcr()
@pytest.mark.asyncio
async def test_file_transcript_modal_processor():
    """Test FileTranscriptModalProcessor using public audio URL and Modal API"""
    from reflector.settings import settings

    processor = FileTranscriptModalProcessor(
        modal_api_key=settings.TRANSCRIPT_MODAL_API_KEY
    )

    test_input = FileTranscriptInput(
        audio_url=TEST_AUDIO_URL,
        language="en",
    )

    # This will record the HTTP interaction on first run, replay on subsequent runs
    result = await processor._transcript(test_input)

    # Verify the result structure
    assert result is not None
    assert hasattr(result, "words")
    assert isinstance(result.words, list)

    # Check structure of each word if present
    for word in result.words:
        assert hasattr(word, "text")
        assert hasattr(word, "start")
        assert hasattr(word, "end")
        assert isinstance(word.start, (int, float))
        assert isinstance(word.end, (int, float))
        assert isinstance(word.text, str)
        # Basic sanity check - start should be before or equal to end
        assert word.start <= word.end


@pytest.mark.asyncio
async def test_transcript_diarization_assembler_processor():
    """Test TranscriptDiarizationAssemblerProcessor without VCR (no HTTP requests)"""
    # Create test transcript with words
    words = [
        Word(text="Hello", start=0.0, end=1.0, speaker=0),
        Word(text=" ", start=1.0, end=1.1, speaker=0),
        Word(text="world", start=1.1, end=2.0, speaker=0),
        Word(text=".", start=2.0, end=2.1, speaker=0),
        Word(text=" ", start=2.1, end=2.2, speaker=0),
        Word(text="How", start=2.2, end=2.8, speaker=0),
        Word(text=" ", start=2.8, end=2.9, speaker=0),
        Word(text="are", start=2.9, end=3.2, speaker=0),
        Word(text=" ", start=3.2, end=3.3, speaker=0),
        Word(text="you", start=3.3, end=3.8, speaker=0),
        Word(text="?", start=3.8, end=3.9, speaker=0),
    ]
    transcript = Transcript(words=words)

    # Create test diarization segments
    diarization = [
        DiarizationSegment(start=0.0, end=2.1, speaker=0),
        DiarizationSegment(start=2.1, end=3.9, speaker=1),
    ]

    # Create processor and test input
    processor = TranscriptDiarizationAssemblerProcessor()
    test_input = TranscriptDiarizationAssemblerInput(
        transcript=transcript, diarization=diarization
    )

    # Track emitted results
    emitted_results = []

    async def capture_result(result):
        emitted_results.append(result)

    processor.on(capture_result)

    # Process the input
    await processor.push(test_input)

    # Verify result was emitted
    assert len(emitted_results) == 1
    result = emitted_results[0]

    # Verify result structure
    assert isinstance(result, Transcript)
    assert len(result.words) == len(words)

    # Verify speaker assignments were applied
    # Words 0-3 (indices) should be speaker 0 (time 0.0-2.0)
    # Words 4-10 (indices) should be speaker 1 (time 2.1-3.9)
    for i in range(4):  # First 4 words (Hello, space, world, .)
        assert (
            result.words[i].speaker == 0
        ), f"Word {i} '{result.words[i].text}' should be speaker 0, got {result.words[i].speaker}"

    for i in range(4, 11):  # Remaining words (space, How, space, are, space, you, ?)
        assert (
            result.words[i].speaker == 1
        ), f"Word {i} '{result.words[i].text}' should be speaker 1, got {result.words[i].speaker}"


@pytest.mark.asyncio
async def test_transcript_diarization_assembler_no_diarization():
    """Test TranscriptDiarizationAssemblerProcessor with no diarization data"""
    # Create test transcript
    words = [Word(text="Hello", start=0.0, end=1.0, speaker=0)]
    transcript = Transcript(words=words)

    # Create processor and test input with empty diarization
    processor = TranscriptDiarizationAssemblerProcessor()
    test_input = TranscriptDiarizationAssemblerInput(
        transcript=transcript, diarization=[]
    )

    # Track emitted results
    emitted_results = []

    async def capture_result(result):
        emitted_results.append(result)

    processor.on(capture_result)

    # Process the input
    await processor.push(test_input)

    # Verify original transcript was returned unchanged
    assert len(emitted_results) == 1
    result = emitted_results[0]
    assert result is transcript  # Should be the same object
    assert result.words[0].speaker == 0  # Original speaker unchanged


@pytest.mark.vcr()
@pytest.mark.asyncio
async def test_full_modal_pipeline_integration(vcr):
    """Integration test: Transcription -> Diarization -> Assembly

    This test demonstrates the full pipeline:
    1. Run transcription via Modal
    2. Run diarization via Modal
    3. Assemble transcript with diarization
    """
    from reflector.settings import settings

    # Step 1: Transcription
    transcript_processor = FileTranscriptModalProcessor(
        modal_api_key=settings.TRANSCRIPT_MODAL_API_KEY
    )
    transcript_input = FileTranscriptInput(audio_url=TEST_AUDIO_URL, language="en")
    transcript = await transcript_processor._transcript(transcript_input)

    # Step 2: Diarization
    diarization_processor = FileDiarizationModalProcessor(
        modal_api_key=settings.DIARIZATION_MODAL_API_KEY
    )
    diarization_input = FileDiarizationInput(audio_url=TEST_AUDIO_URL)
    diarization_result = await diarization_processor._diarize(diarization_input)

    # Step 3: Assembly
    assembler = TranscriptDiarizationAssemblerProcessor()
    assembly_input = TranscriptDiarizationAssemblerInput(
        transcript=transcript, diarization=diarization_result.diarization
    )

    # Track assembled result
    assembled_results = []

    async def capture_result(result):
        assembled_results.append(result)

    assembler.on(capture_result)

    await assembler.push(assembly_input)

    # Verify the full pipeline worked
    assert len(assembled_results) == 1
    final_transcript = assembled_results[0]

    # Verify the final transcript has the original words with updated speaker info
    assert isinstance(final_transcript, Transcript)
    assert len(final_transcript.words) == len(transcript.words)
    assert len(final_transcript.words) > 0

    # Verify some words have been assigned speakers from diarization
    speakers_found = set(word.speaker for word in final_transcript.words)
    assert len(speakers_found) > 0  # At least some speaker assignments
@@ -2,13 +2,10 @@ import pytest


@pytest.mark.asyncio
@pytest.mark.parametrize("enable_diarization", [False, True])
async def test_basic_process(
    dummy_transcript,
    dummy_llm,
    dummy_processors,
    enable_diarization,
    dummy_diarization,
):
    # goal is to start the server, and send rtc audio to it
    # validate the events received
@@ -31,31 +28,12 @@ async def test_basic_process(

    # invoke the process and capture events
    path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"

    if enable_diarization:
        # Test with diarization - may fail if pyannote.audio is not installed
        try:
            await process_audio_file(
                path.as_posix(), event_callback, enable_diarization=True
            )
        except SystemExit:
            pytest.skip("pyannote.audio not installed - skipping diarization test")
    else:
        # Test without diarization - should always work
        await process_audio_file(
            path.as_posix(), event_callback, enable_diarization=False
        )

    print(f"Diarization: {enable_diarization}, Marks: {marks}")
    await process_audio_file(path.as_posix(), event_callback)
    print(marks)

    # validate the events
    # Each processor should be called for each audio segment processed
    # The final processors (Topic, Title, Summary) should be called once at the end
    assert marks["TranscriptLinerProcessor"] > 0
    assert marks["TranscriptTranslatorPassthroughProcessor"] > 0
    assert marks["TranscriptLinerProcessor"] == 1
    assert marks["TranscriptTranslatorPassthroughProcessor"] == 1
    assert marks["TranscriptTopicDetectorProcessor"] == 1
    assert marks["TranscriptFinalSummaryProcessor"] == 1
    assert marks["TranscriptFinalTitleProcessor"] == 1

    if enable_diarization:
        assert marks["TestAudioDiarizationProcessor"] == 1
225 server/tests/test_room_ics.py Normal file
@@ -0,0 +1,225 @@
"""
|
||||
Tests for Room model ICS calendar integration fields.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from reflector.db.rooms import rooms_controller
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_room_create_with_ics_fields():
|
||||
"""Test creating a room with ICS calendar fields."""
|
||||
room = await rooms_controller.add(
|
||||
name="test-room",
|
||||
user_id="test-user",
|
||||
zulip_auto_post=False,
|
||||
zulip_stream="",
|
||||
zulip_topic="",
|
||||
is_locked=False,
|
||||
room_mode="normal",
|
||||
recording_type="cloud",
|
||||
recording_trigger="automatic-2nd-participant",
|
||||
is_shared=False,
|
||||
ics_url="https://calendar.google.com/calendar/ical/test/private-token/basic.ics",
|
||||
ics_fetch_interval=600,
|
||||
ics_enabled=True,
|
||||
)
|
||||
|
||||
assert room.name == "test-room"
|
||||
assert (
|
||||
room.ics_url
|
||||
== "https://calendar.google.com/calendar/ical/test/private-token/basic.ics"
|
||||
)
|
||||
assert room.ics_fetch_interval == 600
|
||||
assert room.ics_enabled is True
|
||||
assert room.ics_last_sync is None
|
||||
assert room.ics_last_etag is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_room_update_ics_configuration():
|
||||
"""Test updating room ICS configuration."""
|
||||
# Create room without ICS
|
||||
room = await rooms_controller.add(
|
||||
name="update-test",
|
||||
user_id="test-user",
|
||||
zulip_auto_post=False,
|
||||
zulip_stream="",
|
||||
zulip_topic="",
|
||||
is_locked=False,
|
||||
room_mode="normal",
|
||||
recording_type="cloud",
|
||||
recording_trigger="automatic-2nd-participant",
|
||||
is_shared=False,
|
||||
)
|
||||
|
||||
assert room.ics_enabled is False
|
||||
assert room.ics_url is None
|
||||
|
||||
# Update with ICS configuration
|
||||
await rooms_controller.update(
|
||||
room,
|
||||
{
|
||||
"ics_url": "https://outlook.office365.com/owa/calendar/test/calendar.ics",
|
||||
"ics_fetch_interval": 300,
|
||||
"ics_enabled": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert (
|
||||
room.ics_url == "https://outlook.office365.com/owa/calendar/test/calendar.ics"
|
||||
)
|
||||
assert room.ics_fetch_interval == 300
|
||||
assert room.ics_enabled is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_room_ics_sync_metadata():
|
||||
"""Test updating room ICS sync metadata."""
|
||||
room = await rooms_controller.add(
|
||||
name="sync-test",
|
||||
user_id="test-user",
|
||||
zulip_auto_post=False,
|
||||
zulip_stream="",
|
||||
zulip_topic="",
|
||||
is_locked=False,
|
||||
room_mode="normal",
|
||||
recording_type="cloud",
|
||||
recording_trigger="automatic-2nd-participant",
|
||||
is_shared=False,
|
||||
ics_url="https://example.com/calendar.ics",
|
||||
ics_enabled=True,
|
||||
)
|
||||
|
||||
# Update sync metadata
|
||||
sync_time = datetime.now(timezone.utc)
|
||||
await rooms_controller.update(
|
||||
room,
|
||||
{
|
||||
"ics_last_sync": sync_time,
|
||||
"ics_last_etag": "abc123hash",
|
||||
},
|
||||
)
|
||||
|
||||
assert room.ics_last_sync == sync_time
|
||||
assert room.ics_last_etag == "abc123hash"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_room_get_with_ics_fields():
|
||||
"""Test retrieving room with ICS fields."""
|
||||
# Create room
|
||||
created_room = await rooms_controller.add(
|
||||
name="get-test",
|
||||
user_id="test-user",
|
||||
zulip_auto_post=False,
|
||||
zulip_stream="",
|
||||
zulip_topic="",
|
||||
is_locked=False,
|
||||
room_mode="normal",
|
||||
recording_type="cloud",
|
||||
recording_trigger="automatic-2nd-participant",
|
||||
is_shared=False,
|
||||
ics_url="webcal://calendar.example.com/feed.ics",
|
||||
ics_fetch_interval=900,
|
||||
ics_enabled=True,
|
||||
)
|
||||
|
||||
# Get by ID
|
||||
room = await rooms_controller.get_by_id(created_room.id)
|
||||
assert room is not None
|
||||
assert room.ics_url == "webcal://calendar.example.com/feed.ics"
|
||||
assert room.ics_fetch_interval == 900
|
||||
assert room.ics_enabled is True
|
||||
|
||||
# Get by name
|
||||
room = await rooms_controller.get_by_name("get-test")
|
||||
assert room is not None
|
||||
assert room.ics_url == "webcal://calendar.example.com/feed.ics"
|
||||
assert room.ics_fetch_interval == 900
|
||||
assert room.ics_enabled is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_room_list_with_ics_enabled_filter():
|
||||
"""Test listing rooms filtered by ICS enabled status."""
|
||||
# Create rooms with and without ICS
|
||||
room1 = await rooms_controller.add(
|
||||
name="ics-enabled-1",
|
||||
user_id="test-user",
|
||||
zulip_auto_post=False,
|
||||
zulip_stream="",
|
||||
zulip_topic="",
|
||||
is_locked=False,
|
||||
room_mode="normal",
|
||||
recording_type="cloud",
|
||||
recording_trigger="automatic-2nd-participant",
|
||||
is_shared=True,
|
||||
ics_enabled=True,
|
||||
ics_url="https://calendar1.example.com/feed.ics",
|
||||
)
|
||||
|
||||
room2 = await rooms_controller.add(
|
||||
name="ics-disabled",
|
||||
user_id="test-user",
|
||||
zulip_auto_post=False,
|
||||
zulip_stream="",
|
||||
zulip_topic="",
|
||||
is_locked=False,
|
||||
room_mode="normal",
|
||||
recording_type="cloud",
|
||||
recording_trigger="automatic-2nd-participant",
|
||||
is_shared=True,
|
||||
ics_enabled=False,
|
||||
)
|
||||
|
||||
room3 = await rooms_controller.add(
|
||||
name="ics-enabled-2",
|
||||
user_id="test-user",
|
||||
zulip_auto_post=False,
|
||||
zulip_stream="",
|
||||
zulip_topic="",
|
||||
is_locked=False,
|
||||
room_mode="normal",
|
||||
recording_type="cloud",
|
||||
recording_trigger="automatic-2nd-participant",
|
||||
is_shared=True,
|
||||
ics_enabled=True,
|
||||
ics_url="https://calendar2.example.com/feed.ics",
|
||||
)
|
||||
|
||||
# Get all rooms
|
||||
all_rooms = await rooms_controller.get_all()
|
||||
assert len(all_rooms) == 3
|
||||
|
||||
# Filter for ICS-enabled rooms (would need to implement this in controller)
|
||||
ics_rooms = [r for r in all_rooms if r["ics_enabled"]]
|
||||
assert len(ics_rooms) == 2
|
||||
assert all(r["ics_enabled"] for r in ics_rooms)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_room_default_ics_values():
|
||||
"""Test that ICS fields have correct default values."""
|
||||
room = await rooms_controller.add(
|
||||
name="default-test",
|
||||
user_id="test-user",
|
||||
zulip_auto_post=False,
|
||||
zulip_stream="",
|
||||
zulip_topic="",
|
||||
is_locked=False,
|
||||
room_mode="normal",
|
||||
recording_type="cloud",
|
||||
recording_trigger="automatic-2nd-participant",
|
||||
is_shared=False,
|
||||
# Don't specify ICS fields
|
||||
)
|
||||
|
||||
assert room.ics_url is None
|
||||
assert room.ics_fetch_interval == 300 # Default 5 minutes
|
||||
assert room.ics_enabled is False
|
||||
assert room.ics_last_sync is None
|
||||
assert room.ics_last_etag is None
|
||||
385 server/tests/test_room_ics_api.py Normal file
@@ -0,0 +1,385 @@
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, patch

import pytest
from icalendar import Calendar, Event

from reflector.db.calendar_events import CalendarEvent, calendar_events_controller
from reflector.db.rooms import rooms_controller


@pytest.fixture
async def authenticated_client(client):
    from reflector.app import app
    from reflector.auth import current_user_optional

    app.dependency_overrides[current_user_optional] = lambda: {
        "sub": "test-user",
        "email": "test@example.com",
    }
    yield client
    del app.dependency_overrides[current_user_optional]


@pytest.mark.asyncio
async def test_create_room_with_ics_fields(authenticated_client):
    client = authenticated_client
    response = await client.post(
        "/rooms",
        json={
            "name": "test-ics-room",
            "zulip_auto_post": False,
            "zulip_stream": "",
            "zulip_topic": "",
            "is_locked": False,
            "room_mode": "normal",
            "recording_type": "cloud",
            "recording_trigger": "automatic-2nd-participant",
            "is_shared": False,
            "ics_url": "https://calendar.example.com/test.ics",
            "ics_fetch_interval": 600,
            "ics_enabled": True,
        },
    )
    assert response.status_code == 200
    data = response.json()
    assert data["name"] == "test-ics-room"
    assert data["ics_url"] == "https://calendar.example.com/test.ics"
    assert data["ics_fetch_interval"] == 600
    assert data["ics_enabled"] is True


@pytest.mark.asyncio
async def test_update_room_ics_configuration(authenticated_client):
    client = authenticated_client
    response = await client.post(
        "/rooms",
        json={
            "name": "update-ics-room",
            "zulip_auto_post": False,
            "zulip_stream": "",
            "zulip_topic": "",
            "is_locked": False,
            "room_mode": "normal",
            "recording_type": "cloud",
            "recording_trigger": "automatic-2nd-participant",
            "is_shared": False,
        },
    )
    assert response.status_code == 200
    room_id = response.json()["id"]

    response = await client.patch(
        f"/rooms/{room_id}",
        json={
            "ics_url": "https://calendar.google.com/updated.ics",
            "ics_fetch_interval": 300,
            "ics_enabled": True,
        },
    )
    assert response.status_code == 200
    data = response.json()
    assert data["ics_url"] == "https://calendar.google.com/updated.ics"
    assert data["ics_fetch_interval"] == 300
    assert data["ics_enabled"] is True


@pytest.mark.asyncio
async def test_trigger_ics_sync(authenticated_client):
    client = authenticated_client
    room = await rooms_controller.add(
        name="sync-api-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
        ics_url="https://calendar.example.com/api.ics",
        ics_enabled=True,
    )

    cal = Calendar()
    event = Event()
    event.add("uid", "api-test-event")
    event.add("summary", "API Test Meeting")
    from reflector.settings import settings

    event.add("location", f"{settings.BASE_URL}/room/{room.name}")
    now = datetime.now(timezone.utc)
    event.add("dtstart", now + timedelta(hours=1))
    event.add("dtend", now + timedelta(hours=2))
    cal.add_component(event)
    ics_content = cal.to_ical().decode("utf-8")
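
    # The mocked fetch below returns roughly the following ICS payload
    # (icalendar may fold lines and order properties differently; the
    # angle-bracket values are placeholders for the computed timestamps):
    #
    #   BEGIN:VCALENDAR
    #   BEGIN:VEVENT
    #   UID:api-test-event
    #   SUMMARY:API Test Meeting
    #   DTSTART:<now + 1h>
    #   DTEND:<now + 2h>
    #   LOCATION:<BASE_URL>/room/sync-api-room
    #   END:VEVENT
    #   END:VCALENDAR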

    with patch(
        "reflector.services.ics_sync.ICSFetchService.fetch_ics", new_callable=AsyncMock
    ) as mock_fetch:
        mock_fetch.return_value = ics_content

        response = await client.post(f"/rooms/{room.name}/ics/sync")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "success"
        assert data["events_found"] == 1
        assert data["events_created"] == 1


@pytest.mark.asyncio
async def test_trigger_ics_sync_unauthorized(client):
    room = await rooms_controller.add(
        name="sync-unauth-room",
        user_id="owner-123",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
        ics_url="https://calendar.example.com/api.ics",
        ics_enabled=True,
    )

    response = await client.post(f"/rooms/{room.name}/ics/sync")
    assert response.status_code == 403
    assert "Only room owner can trigger ICS sync" in response.json()["detail"]


@pytest.mark.asyncio
async def test_trigger_ics_sync_not_configured(authenticated_client):
    client = authenticated_client
    room = await rooms_controller.add(
        name="sync-not-configured",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
        ics_enabled=False,
    )

    response = await client.post(f"/rooms/{room.name}/ics/sync")
    assert response.status_code == 400
    assert "ICS not configured" in response.json()["detail"]


@pytest.mark.asyncio
async def test_get_ics_status(authenticated_client):
    client = authenticated_client
    room = await rooms_controller.add(
        name="status-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
        ics_url="https://calendar.example.com/status.ics",
        ics_enabled=True,
        ics_fetch_interval=300,
    )

    now = datetime.now(timezone.utc)
    await rooms_controller.update(
        room,
        {"ics_last_sync": now, "ics_last_etag": "test-etag"},
    )

    response = await client.get(f"/rooms/{room.name}/ics/status")
    assert response.status_code == 200
    data = response.json()
    assert data["status"] == "enabled"
    assert data["last_etag"] == "test-etag"
    assert data["events_count"] == 0


@pytest.mark.asyncio
async def test_get_ics_status_unauthorized(client):
    room = await rooms_controller.add(
        name="status-unauth",
        user_id="owner-456",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
        ics_url="https://calendar.example.com/status.ics",
        ics_enabled=True,
    )

    response = await client.get(f"/rooms/{room.name}/ics/status")
    assert response.status_code == 403
    assert "Only room owner can view ICS status" in response.json()["detail"]


@pytest.mark.asyncio
async def test_list_room_meetings(authenticated_client):
    client = authenticated_client
    room = await rooms_controller.add(
        name="meetings-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
    )

    now = datetime.now(timezone.utc)
    event1 = CalendarEvent(
        room_id=room.id,
        ics_uid="meeting-1",
        title="Past Meeting",
        start_time=now - timedelta(hours=2),
        end_time=now - timedelta(hours=1),
    )
    await calendar_events_controller.upsert(event1)

    event2 = CalendarEvent(
        room_id=room.id,
        ics_uid="meeting-2",
        title="Future Meeting",
        description="Team sync",
        start_time=now + timedelta(hours=1),
        end_time=now + timedelta(hours=2),
        attendees=[{"email": "test@example.com"}],
    )
    await calendar_events_controller.upsert(event2)

    response = await client.get(f"/rooms/{room.name}/meetings")
    assert response.status_code == 200
    data = response.json()
    assert len(data) == 2
    assert data[0]["title"] == "Past Meeting"
    assert data[1]["title"] == "Future Meeting"
    assert data[1]["description"] == "Team sync"
    assert data[1]["attendees"] == [{"email": "test@example.com"}]


@pytest.mark.asyncio
async def test_list_room_meetings_non_owner(client):
    room = await rooms_controller.add(
        name="meetings-privacy",
        user_id="owner-789",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
    )

    event = CalendarEvent(
        room_id=room.id,
        ics_uid="private-meeting",
        title="Meeting Title",
        description="Sensitive info",
        start_time=datetime.now(timezone.utc) + timedelta(hours=1),
        end_time=datetime.now(timezone.utc) + timedelta(hours=2),
        attendees=[{"email": "private@example.com"}],
    )
    await calendar_events_controller.upsert(event)

    response = await client.get(f"/rooms/{room.name}/meetings")
    assert response.status_code == 200
    data = response.json()
    assert len(data) == 1
    assert data[0]["title"] == "Meeting Title"
    assert data[0]["description"] is None
    assert data[0]["attendees"] is None


@pytest.mark.asyncio
async def test_list_upcoming_meetings(authenticated_client):
    client = authenticated_client
    room = await rooms_controller.add(
        name="upcoming-room",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic-2nd-participant",
        is_shared=False,
    )

    now = datetime.now(timezone.utc)

    past_event = CalendarEvent(
        room_id=room.id,
        ics_uid="past",
        title="Past",
        start_time=now - timedelta(hours=1),
        end_time=now - timedelta(minutes=30),
    )
    await calendar_events_controller.upsert(past_event)

    soon_event = CalendarEvent(
        room_id=room.id,
        ics_uid="soon",
        title="Soon",
        start_time=now + timedelta(minutes=15),
        end_time=now + timedelta(minutes=45),
    )
    await calendar_events_controller.upsert(soon_event)

    later_event = CalendarEvent(
        room_id=room.id,
        ics_uid="later",
        title="Later",
        start_time=now + timedelta(hours=2),
        end_time=now + timedelta(hours=3),
    )
    await calendar_events_controller.upsert(later_event)

    response = await client.get(f"/rooms/{room.name}/meetings/upcoming")
    assert response.status_code == 200
    data = response.json()
    assert len(data) == 1
    assert data[0]["title"] == "Soon"

    response = await client.get(
        f"/rooms/{room.name}/meetings/upcoming", params={"minutes_ahead": 180}
    )
    assert response.status_code == 200
    data = response.json()
    assert len(data) == 2
    assert data[0]["title"] == "Soon"
    assert data[1]["title"] == "Later"


@pytest.mark.asyncio
async def test_room_not_found_endpoints(client):
    response = await client.post("/rooms/nonexistent/ics/sync")
    assert response.status_code == 404

    response = await client.get("/rooms/nonexistent/ics/status")
    assert response.status_code == 404

    response = await client.get("/rooms/nonexistent/meetings")
    assert response.status_code == 404

    response = await client.get("/rooms/nonexistent/meetings/upcoming")
    assert response.status_code == 404
|
||||
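The minutes_ahead assertions above reduce to a bounded start-time window. A minimal sketch of that filter, assuming a CalendarEvent-like object with a start_time attribute and a default window of 60 minutes (the actual endpoint implementation and its default are not part of this diff):

from datetime import datetime, timedelta, timezone

def upcoming_events(events, minutes_ahead=60):
    # Keep events that start between now and now + minutes_ahead.
    now = datetime.now(timezone.utc)
    horizon = now + timedelta(minutes=minutes_ahead)
    return [e for e in events if now <= e.start_time <= horizon]

With the fixture data, "Soon" (now + 15 min) falls inside the default window, and widening the window to 180 minutes also admits "Later" (now + 2 h), which is exactly what the two responses assert.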
@@ -2,18 +2,13 @@

import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch

import pytest
from pydantic import ValidationError

from reflector.db import get_database
from reflector.db.search import (
    SearchController,
    SearchParameters,
    SearchResult,
    search_controller,
)
from reflector.db.transcripts import SourceKind, transcripts
from reflector.db.search import SearchParameters, search_controller
from reflector.db.transcripts import transcripts


@pytest.mark.asyncio
@@ -23,137 +18,39 @@ async def test_search_postgresql_only():
    assert results == []
    assert total == 0

    params_empty = SearchParameters(query_text="")
    results_empty, total_empty = await search_controller.search_transcripts(
        params_empty
    )
    assert isinstance(results_empty, list)
    assert isinstance(total_empty, int)


@pytest.mark.asyncio
async def test_search_with_empty_query():
    """Test that empty query returns all transcripts."""
    params = SearchParameters(query_text="")
    results, total = await search_controller.search_transcripts(params)

    assert isinstance(results, list)
    assert isinstance(total, int)
    if len(results) > 1:
        for i in range(len(results) - 1):
            assert results[i].created_at >= results[i + 1].created_at


@pytest.mark.asyncio
async def test_empty_transcript_title_only_match():
    """Test that transcripts with title-only matches return empty snippets."""
    test_id = "test-empty-9b3f2a8d"

    try:
        await get_database().execute(
            transcripts.delete().where(transcripts.c.id == test_id)
        )
        SearchParameters(query_text="")
        assert False, "Should have raised validation error"
    except ValidationError:
        pass  # Expected

        test_data = {
            "id": test_id,
            "name": "Empty Transcript",
            "title": "Empty Meeting",
            "status": "completed",
            "locked": False,
            "duration": 0.0,
            "created_at": datetime.now(timezone.utc),
            "short_summary": None,
            "long_summary": None,
            "topics": json.dumps([]),
            "events": json.dumps([]),
            "participants": json.dumps([]),
            "source_language": "en",
            "target_language": "en",
            "reviewed": False,
            "audio_location": "local",
            "share_mode": "private",
            "source_kind": "room",
            "webvtt": None,
            "user_id": "test-user-1",
        }

        await get_database().execute(transcripts.insert().values(**test_data))

        params = SearchParameters(query_text="empty", user_id="test-user-1")
        results, total = await search_controller.search_transcripts(params)

        assert total >= 1
        found = next((r for r in results if r.id == test_id), None)
        assert found is not None, "Should find transcript by title match"
        assert found.search_snippets == []
        assert found.total_match_count == 0

    finally:
        await get_database().execute(
            transcripts.delete().where(transcripts.c.id == test_id)
        )
        await get_database().disconnect()
    # Test that whitespace query raises validation error
    try:
        SearchParameters(query_text=" ")
        assert False, "Should have raised validation error"
    except ValidationError:
        pass  # Expected


@pytest.mark.asyncio
async def test_search_with_long_summary():
    """Test that long_summary content is searchable."""
    test_id = "test-long-summary-8a9f3c2d"

async def test_search_input_validation():
    try:
        await get_database().execute(
            transcripts.delete().where(transcripts.c.id == test_id)
        )
        SearchParameters(query_text="")
        assert False, "Should have raised ValidationError"
    except ValidationError:
        pass  # Expected

        test_data = {
            "id": test_id,
            "name": "Test Long Summary",
            "title": "Regular Meeting",
            "status": "completed",
            "locked": False,
            "duration": 1800.0,
            "created_at": datetime.now(timezone.utc),
            "short_summary": "Brief overview",
            "long_summary": "Detailed discussion about quantum computing applications and blockchain technology integration",
            "topics": json.dumps([]),
            "events": json.dumps([]),
            "participants": json.dumps([]),
            "source_language": "en",
            "target_language": "en",
            "reviewed": False,
            "audio_location": "local",
            "share_mode": "private",
            "source_kind": "room",
            "webvtt": """WEBVTT

00:00:00.000 --> 00:00:10.000
Basic meeting content without special keywords.""",
            "user_id": "test-user-2",
        }

        await get_database().execute(transcripts.insert().values(**test_data))

        params = SearchParameters(query_text="quantum computing", user_id="test-user-2")
        results, total = await search_controller.search_transcripts(params)

        assert total >= 1
        found = any(r.id == test_id for r in results)
        assert found, "Should find transcript by long_summary content"

        test_result = next((r for r in results if r.id == test_id), None)
        assert test_result
        assert len(test_result.search_snippets) > 0
        assert "quantum computing" in test_result.search_snippets[0].lower()

    finally:
        await get_database().execute(
            transcripts.delete().where(transcripts.c.id == test_id)
        )
        await get_database().disconnect()
    # Test that whitespace query raises validation error
    try:
        SearchParameters(query_text=" \t\n ")
        assert False, "Should have raised ValidationError"
    except ValidationError:
        pass  # Expected


@pytest.mark.asyncio
async def test_postgresql_search_with_data():
    # collision is improbable
    test_id = "test-search-e2e-7f3a9b2c"

    try:
@@ -193,31 +90,32 @@ The search feature should support complex queries with ranking.

00:00:30.000 --> 00:00:40.000
We need to implement PostgreSQL tsvector for better performance.""",
            "user_id": "test-user-3",
        }

        await get_database().execute(transcripts.insert().values(**test_data))

        params = SearchParameters(query_text="planning", user_id="test-user-3")
        # Test 1: Search for a word in title
        params = SearchParameters(query_text="planning")
        results, total = await search_controller.search_transcripts(params)
        assert total >= 1
        found = any(r.id == test_id for r in results)
        assert found, "Should find test transcript by title word"

        params = SearchParameters(query_text="tsvector", user_id="test-user-3")
        # Test 2: Search for a word in webvtt content
        params = SearchParameters(query_text="tsvector")
        results, total = await search_controller.search_transcripts(params)
        assert total >= 1
        found = any(r.id == test_id for r in results)
        assert found, "Should find test transcript by webvtt content"

        params = SearchParameters(
            query_text="engineering planning", user_id="test-user-3"
        )
        # Test 3: Search with multiple words
        params = SearchParameters(query_text="engineering planning")
        results, total = await search_controller.search_transcripts(params)
        assert total >= 1
        found = any(r.id == test_id for r in results)
        assert found, "Should find test transcript by multiple words"

        # Test 4: Verify SearchResult structure
        test_result = next((r for r in results if r.id == test_id), None)
        if test_result:
            assert test_result.title == "Engineering Planning Meeting Q4 2024"
@@ -225,17 +123,15 @@ We need to implement PostgreSQL tsvector for better performance.""",
            assert test_result.duration == 1800.0
            assert 0 <= test_result.rank <= 1, "Rank should be normalized to 0-1"

        params = SearchParameters(
            query_text="tsvector OR nosuchword", user_id="test-user-3"
        )
        # Test 5: Search with OR operator
        params = SearchParameters(query_text="tsvector OR nosuchword")
        results, total = await search_controller.search_transcripts(params)
        assert total >= 1
        found = any(r.id == test_id for r in results)
        assert found, "Should find test transcript with OR query"

        params = SearchParameters(
            query_text='"full-text search"', user_id="test-user-3"
        )
        # Test 6: Quoted phrase search
        params = SearchParameters(query_text='"full-text search"')
        results, total = await search_controller.search_transcripts(params)
        assert total >= 1
        found = any(r.id == test_id for r in results)
@@ -246,240 +142,3 @@ We need to implement PostgreSQL tsvector for better performance.""",
            transcripts.delete().where(transcripts.c.id == test_id)
        )
        await get_database().disconnect()


@pytest.fixture
def sample_search_params():
    """Create sample search parameters for testing."""
    return SearchParameters(
        query_text="test query",
        limit=20,
        offset=0,
        user_id="test-user",
        room_id="room1",
    )


@pytest.fixture
def mock_db_result():
    """Create a mock database result."""
    return {
        "id": "test-transcript-id",
        "title": "Test Transcript",
        "created_at": datetime(2024, 6, 15, tzinfo=timezone.utc),
        "duration": 3600.0,
        "status": "completed",
        "user_id": "test-user",
        "room_id": "room1",
        "source_kind": SourceKind.LIVE,
        "webvtt": "WEBVTT\n\n00:00:00.000 --> 00:00:05.000\nThis is a test transcript",
        "rank": 0.95,
    }


class TestSearchParameters:
    """Test SearchParameters model validation and functionality."""

    def test_search_parameters_with_available_filters(self):
        """Test creating SearchParameters with currently available filter options."""
        params = SearchParameters(
            query_text="search term",
            limit=50,
            offset=10,
            user_id="user123",
            room_id="room1",
        )

        assert params.query_text == "search term"
        assert params.limit == 50
        assert params.offset == 10
        assert params.user_id == "user123"
        assert params.room_id == "room1"

    def test_search_parameters_defaults(self):
        """Test SearchParameters with default values."""
        params = SearchParameters(query_text="test")

        assert params.query_text == "test"
        assert params.limit == 20
        assert params.offset == 0
        assert params.user_id is None
        assert params.room_id is None


class TestSearchControllerFilters:
    """Test SearchController functionality with various filters."""

    @pytest.mark.asyncio
    async def test_search_with_source_kind_filter(self):
        """Test search filtering by source_kind."""
        controller = SearchController()
        with (
            patch("reflector.db.search.is_postgresql", return_value=True),
            patch("reflector.db.search.get_database") as mock_db,
        ):
            mock_db.return_value.fetch_all = AsyncMock(return_value=[])
            mock_db.return_value.fetch_val = AsyncMock(return_value=0)

            params = SearchParameters(query_text="test", source_kind=SourceKind.LIVE)

            results, total = await controller.search_transcripts(params)

            assert results == []
            assert total == 0

            mock_db.return_value.fetch_all.assert_called_once()

    @pytest.mark.asyncio
    async def test_search_with_single_room_id(self):
        """Test search filtering by single room ID (currently supported)."""
        controller = SearchController()
        with (
            patch("reflector.db.search.is_postgresql", return_value=True),
            patch("reflector.db.search.get_database") as mock_db,
        ):
            mock_db.return_value.fetch_all = AsyncMock(return_value=[])
            mock_db.return_value.fetch_val = AsyncMock(return_value=0)

            params = SearchParameters(
                query_text="test",
                room_id="room1",
            )

            results, total = await controller.search_transcripts(params)

            assert results == []
            assert total == 0
            mock_db.return_value.fetch_all.assert_called_once()

    @pytest.mark.asyncio
    async def test_search_result_includes_available_fields(self, mock_db_result):
        """Test that search results include available fields like source_kind."""
        controller = SearchController()
        with (
            patch("reflector.db.search.is_postgresql", return_value=True),
            patch("reflector.db.search.get_database") as mock_db,
        ):

            class MockRow:
                def __init__(self, data):
                    self._data = data
                    self._mapping = data

                def __iter__(self):
                    return iter(self._data.items())

                def __getitem__(self, key):
                    return self._data[key]

                def keys(self):
                    return self._data.keys()

            mock_row = MockRow(mock_db_result)

            mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
            mock_db.return_value.fetch_val = AsyncMock(return_value=1)

            params = SearchParameters(query_text="test")

            results, total = await controller.search_transcripts(params)

            assert total == 1
            assert len(results) == 1

            result = results[0]
            assert isinstance(result, SearchResult)
            assert result.id == "test-transcript-id"
            assert result.title == "Test Transcript"
            assert result.rank == 0.95


class TestSearchEndpointParsing:
    """Test parameter parsing in the search endpoint."""

    def test_parse_comma_separated_room_ids(self):
        """Test parsing comma-separated room IDs."""
        room_ids_str = "room1,room2,room3"
        parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
        assert parsed == ["room1", "room2", "room3"]

        room_ids_str = "room1, room2 , room3"
        parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
        assert parsed == ["room1", "room2", "room3"]

        room_ids_str = "room1,,room3,"
        parsed = [rid.strip() for rid in room_ids_str.split(",") if rid.strip()]
        assert parsed == ["room1", "room3"]

    def test_parse_source_kind(self):
        """Test parsing source_kind values."""
        for kind_str in ["live", "file", "room"]:
            parsed = SourceKind(kind_str)
            assert parsed == SourceKind(kind_str)

        with pytest.raises(ValueError):
            SourceKind("invalid_kind")


class TestSearchResultModel:
    """Test SearchResult model and serialization."""

    def test_search_result_with_available_fields(self):
        """Test SearchResult model with currently available fields populated."""
        result = SearchResult(
            id="test-id",
            title="Test Title",
            user_id="user-123",
            room_id="room-456",
            source_kind=SourceKind.ROOM,
            created_at=datetime(2024, 6, 15, tzinfo=timezone.utc),
            status="completed",
            rank=0.85,
            duration=1800.5,
            search_snippets=["snippet 1", "snippet 2"],
        )

        assert result.id == "test-id"
        assert result.title == "Test Title"
        assert result.user_id == "user-123"
        assert result.room_id == "room-456"
        assert result.status == "completed"
        assert result.rank == 0.85
        assert result.duration == 1800.5
        assert len(result.search_snippets) == 2

    def test_search_result_with_optional_fields_none(self):
        """Test SearchResult model with optional fields as None."""
        result = SearchResult(
            id="test-id",
            source_kind=SourceKind.FILE,
            created_at=datetime.now(timezone.utc),
            status="processing",
            rank=0.5,
            search_snippets=[],
            title=None,
            user_id=None,
            room_id=None,
            duration=None,
        )

        assert result.title is None
        assert result.user_id is None
        assert result.room_id is None
        assert result.duration is None

    def test_search_result_datetime_field(self):
        """Test that SearchResult accepts datetime field."""
        result = SearchResult(
            id="test-id",
            source_kind=SourceKind.LIVE,
            created_at=datetime(2024, 6, 15, 12, 30, 45, tzinfo=timezone.utc),
            status="completed",
            rank=0.9,
            duration=None,
            search_snippets=[],
        )

        assert result.created_at == datetime(
            2024, 6, 15, 12, 30, 45, tzinfo=timezone.utc
        )
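The OR and quoted-phrase queries exercised in Test 5 and Test 6 line up with PostgreSQL's websearch_to_tsquery syntax, which treats OR as a disjunction and double-quoted phrases as adjacency constraints. The real SQL lives in reflector.db.search and is not shown in this diff; a sketch of the kind of statement these tests imply, with the column mix and ranking expression assumed:

SEARCH_SQL = """
SELECT id, title,
       ts_rank(
           to_tsvector('english', coalesce(title, '') || ' ' || coalesce(long_summary, '') || ' ' || coalesce(webvtt, '')),
           websearch_to_tsquery('english', :query_text)
       ) AS rank
FROM transcripts
WHERE to_tsvector('english', coalesce(title, '') || ' ' || coalesce(long_summary, '') || ' ' || coalesce(webvtt, ''))
      @@ websearch_to_tsquery('english', :query_text)
ORDER BY rank DESC
LIMIT :limit OFFSET :offset
"""

Under this reading, "tsvector OR nosuchword" matches as long as one branch matches, and '"full-text search"' only matches the phrase, which is what the fixture asserts.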
@@ -1,166 +0,0 @@
"""Tests for long_summary in search functionality."""

import json
from datetime import datetime, timezone

import pytest

from reflector.db import get_database
from reflector.db.search import SearchParameters, search_controller
from reflector.db.transcripts import transcripts


@pytest.mark.asyncio
async def test_long_summary_snippet_prioritization():
    """Test that snippets from long_summary are prioritized over webvtt content."""
    test_id = "test-snippet-priority-3f9a2b8c"

    try:
        # Clean up any existing test data
        await get_database().execute(
            transcripts.delete().where(transcripts.c.id == test_id)
        )

        test_data = {
            "id": test_id,
            "name": "Test Snippet Priority",
            "title": "Meeting About Projects",
            "status": "completed",
            "locked": False,
            "duration": 1800.0,
            "created_at": datetime.now(timezone.utc),
            "short_summary": "Project discussion",
            "long_summary": (
                "The team discussed advanced robotics applications including "
                "autonomous navigation systems and sensor fusion techniques. "
                "Robotics development will focus on real-time processing."
            ),
            "topics": json.dumps([]),
            "events": json.dumps([]),
            "participants": json.dumps([]),
            "source_language": "en",
            "target_language": "en",
            "reviewed": False,
            "audio_location": "local",
            "share_mode": "private",
            "source_kind": "room",
            "webvtt": """WEBVTT

00:00:00.000 --> 00:00:10.000
We talked about many different topics today.

00:00:10.000 --> 00:00:20.000
The robotics project is making good progress.

00:00:20.000 --> 00:00:30.000
We need to consider various implementation approaches.""",
            "user_id": "test-user-priority",
        }

        await get_database().execute(transcripts.insert().values(**test_data))

        # Search for "robotics" which appears in both long_summary and webvtt
        params = SearchParameters(query_text="robotics", user_id="test-user-priority")
        results, total = await search_controller.search_transcripts(params)

        assert total >= 1
        test_result = next((r for r in results if r.id == test_id), None)
        assert test_result, "Should find the test transcript"

        snippets = test_result.search_snippets
        assert len(snippets) > 0, "Should have at least one snippet"

        # The first snippets should be from long_summary (more detailed content)
        first_snippet = snippets[0].lower()
        assert (
            "advanced robotics" in first_snippet or "autonomous" in first_snippet
        ), f"First snippet should be from long_summary with detailed content. Got: {snippets[0]}"

        # With max 3 snippets, we should get both from long_summary and webvtt
        assert len(snippets) <= 3, "Should respect max snippets limit"

        # All snippets should contain the search term
        for snippet in snippets:
            assert (
                "robotics" in snippet.lower()
            ), f"Snippet should contain search term: {snippet}"

    finally:
        await get_database().execute(
            transcripts.delete().where(transcripts.c.id == test_id)
        )
        await get_database().disconnect()


@pytest.mark.asyncio
async def test_long_summary_only_search():
    """Test searching for content that only exists in long_summary."""
    test_id = "test-long-only-8b3c9f2a"

    try:
        await get_database().execute(
            transcripts.delete().where(transcripts.c.id == test_id)
        )

        test_data = {
            "id": test_id,
            "name": "Test Long Only",
            "title": "Standard Meeting",
            "status": "completed",
            "locked": False,
            "duration": 1800.0,
            "created_at": datetime.now(timezone.utc),
            "short_summary": "Team sync",
            "long_summary": (
                "Detailed analysis of cryptocurrency market trends and "
                "decentralized finance protocols. Discussion included "
                "yield farming strategies and liquidity pool mechanics."
            ),
            "topics": json.dumps([]),
            "events": json.dumps([]),
            "participants": json.dumps([]),
            "source_language": "en",
            "target_language": "en",
            "reviewed": False,
            "audio_location": "local",
            "share_mode": "private",
            "source_kind": "room",
            "webvtt": """WEBVTT

00:00:00.000 --> 00:00:10.000
Team meeting about general project updates.

00:00:10.000 --> 00:00:20.000
Discussion of timeline and deliverables.""",
            "user_id": "test-user-long",
        }

        await get_database().execute(transcripts.insert().values(**test_data))

        # Search for terms only in long_summary
        params = SearchParameters(query_text="cryptocurrency", user_id="test-user-long")
        results, total = await search_controller.search_transcripts(params)

        found = any(r.id == test_id for r in results)
        assert found, "Should find transcript by long_summary-only content"

        test_result = next((r for r in results if r.id == test_id), None)
        assert test_result
        assert len(test_result.search_snippets) > 0

        # Verify the snippet is about cryptocurrency
        snippet = test_result.search_snippets[0].lower()
        assert "cryptocurrency" in snippet, "Snippet should contain the search term"

        # Search for "yield farming" - a more specific term
        params2 = SearchParameters(query_text="yield farming", user_id="test-user-long")
        results2, total2 = await search_controller.search_transcripts(params2)

        found2 = any(r.id == test_id for r in results2)
        assert found2, "Should find transcript by specific long_summary phrase"

    finally:
        await get_database().execute(
            transcripts.delete().where(transcripts.c.id == test_id)
        )
        await get_database().disconnect()
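The ordering this deleted file covered, long_summary snippets ahead of webvtt ones, is asserted again in the SnippetGenerator tests that follow. A rough, runnable approximation of the combine_sources counting contract those tests rely on (inferred from the asserted behavior, not copied from reflector.db.search):

import re

def count_matches(text, query):
    # Case-insensitive occurrence count; 0 for missing sources.
    return len(re.findall(re.escape(query), text or "", flags=re.IGNORECASE))

def combine_sources_sketch(summary, webvtt_text, query, max_total=3):
    # Summary-derived snippets come first, webvtt snippets fill the remainder;
    # total_match_count sums occurrences across both sources.
    snippets = []
    if count_matches(summary, query):
        snippets.append(summary)
    if count_matches(webvtt_text, query):
        snippets.append(webvtt_text)
    return snippets[:max_total], count_matches(summary, query) + count_matches(webvtt_text, query)

This reproduces, for example, total_count == 6 in test_match_counting_sum_logic below: three "data" hits in the summary plus three in the webvtt text.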
@@ -1,10 +1,6 @@
"""Unit tests for search snippet generation."""

from reflector.db.search import (
    SnippetCandidate,
    SnippetGenerator,
    WebVTTProcessor,
)
from reflector.db.search import SearchController


class TestExtractWebVTT:
@@ -20,7 +16,7 @@ class TestExtractWebVTT:
00:00:10.000 --> 00:00:20.000
<v Speaker1>Indeed it is a test of WebVTT parsing.
"""
        result = WebVTTProcessor.extract_text(webvtt)
        result = SearchController._extract_webvtt_text(webvtt)
        assert "Hello world, this is a test" in result
        assert "Indeed it is a test" in result
        assert "<v Speaker" not in result
@@ -29,11 +25,12 @@ class TestExtractWebVTT:

    def test_extract_empty_webvtt(self):
        """Test empty WebVTT returns empty string."""
        assert WebVTTProcessor.extract_text("") == ""
        assert SearchController._extract_webvtt_text("") == ""
        assert SearchController._extract_webvtt_text(None) == ""

    def test_extract_malformed_webvtt(self):
        """Test malformed WebVTT returns empty string."""
        result = WebVTTProcessor.extract_text("Not a valid WebVTT")
        result = SearchController._extract_webvtt_text("Not a valid WebVTT")
        assert result == ""


@@ -42,7 +39,8 @@ class TestGenerateSnippets:

    def test_multiple_matches(self):
        """Test finding multiple occurrences of search term in long text."""
        separator = " This is filler text. " * 20
        # Create text with Python mentions far apart to get separate snippets
        separator = " This is filler text. " * 20  # ~400 chars of padding
        text = (
            "Python is great for machine learning."
            + separator
@@ -53,16 +51,18 @@ class TestGenerateSnippets:
            + "The Python community is very supportive."
        )

        snippets = SnippetGenerator.generate(text, "Python")
        assert len(snippets) >= 2
        snippets = SearchController._generate_snippets(text, "Python")
        # With enough separation, we should get multiple snippets
        assert len(snippets) >= 2  # At least 2 distinct snippets

        # Each snippet should contain "Python"
        for snippet in snippets:
            assert "python" in snippet.lower()

    def test_single_match(self):
        """Test single occurrence returns one snippet."""
        text = "This document discusses artificial intelligence and its applications."
        snippets = SnippetGenerator.generate(text, "artificial intelligence")
        snippets = SearchController._generate_snippets(text, "artificial intelligence")

        assert len(snippets) == 1
        assert "artificial intelligence" in snippets[0].lower()
@@ -70,22 +70,24 @@ class TestGenerateSnippets:
    def test_no_matches(self):
        """Test no matches returns empty list."""
        text = "This is some random text without the search term."
        snippets = SnippetGenerator.generate(text, "machine learning")
        snippets = SearchController._generate_snippets(text, "machine learning")

        assert snippets == []

    def test_case_insensitive_search(self):
        """Test search is case insensitive."""
        # Add enough text between matches to get separate snippets
        text = (
            "MACHINE LEARNING is important for modern applications. "
            + "It requires lots of data and computational resources. " * 5
            + "It requires lots of data and computational resources. " * 5  # Padding
            + "Machine Learning rocks and transforms industries. "
            + "Deep learning is a subset of it. " * 5
            + "Deep learning is a subset of it. " * 5  # More padding
            + "Finally, machine learning will shape our future."
        )

        snippets = SnippetGenerator.generate(text, "machine learning")
        snippets = SearchController._generate_snippets(text, "machine learning")

        # Should find at least 2 (might be 3 if text is long enough)
        assert len(snippets) >= 2
        for snippet in snippets:
            assert "machine learning" in snippet.lower()
@@ -93,55 +95,61 @@ class TestGenerateSnippets:
    def test_partial_match_fallback(self):
        """Test fallback to first word when exact phrase not found."""
        text = "We use machine intelligence for processing."
        snippets = SnippetGenerator.generate(text, "machine learning")
        snippets = SearchController._generate_snippets(text, "machine learning")

        # Should fall back to finding "machine"
        assert len(snippets) == 1
        assert "machine" in snippets[0].lower()

    def test_snippet_ellipsis(self):
        """Test ellipsis added for truncated snippets."""
        # Long text where match is in the middle
        text = "a " * 100 + "TARGET_WORD special content here" + " b" * 100
        snippets = SnippetGenerator.generate(text, "TARGET_WORD")
        snippets = SearchController._generate_snippets(text, "TARGET_WORD")

        assert len(snippets) == 1
        assert "..." in snippets[0]
        assert "..." in snippets[0]  # Should have ellipsis
        assert "TARGET_WORD" in snippets[0]

    def test_overlapping_snippets_deduplicated(self):
        """Test overlapping matches don't create duplicate snippets."""
        text = "test test test word" * 10
        snippets = SnippetGenerator.generate(text, "test")
        text = "test test test word" * 10  # Repeated pattern
        snippets = SearchController._generate_snippets(text, "test")

        # Should get unique snippets, not duplicates
        assert len(snippets) <= 3
        assert len(snippets) == len(set(snippets))
        assert len(snippets) == len(set(snippets))  # All unique

    def test_empty_inputs(self):
        """Test empty text or search term returns empty list."""
        assert SnippetGenerator.generate("", "search") == []
        assert SnippetGenerator.generate("text", "") == []
        assert SnippetGenerator.generate("", "") == []
        assert SearchController._generate_snippets("", "search") == []
        assert SearchController._generate_snippets("text", "") == []
        assert SearchController._generate_snippets("", "") == []

    def test_max_snippets_limit(self):
        """Test respects max_snippets parameter."""
        separator = " filler " * 50
        text = ("Python is amazing" + separator) * 10
        # Create text with well-separated occurrences
        separator = " filler " * 50  # Ensure snippets don't overlap
        text = ("Python is amazing" + separator) * 10  # 10 occurrences

        snippets_1 = SnippetGenerator.generate(text, "Python", max_snippets=1)
        # Test with different limits
        snippets_1 = SearchController._generate_snippets(text, "Python", max_snippets=1)
        assert len(snippets_1) == 1

        snippets_2 = SnippetGenerator.generate(text, "Python", max_snippets=2)
        snippets_2 = SearchController._generate_snippets(text, "Python", max_snippets=2)
        assert len(snippets_2) == 2

        snippets_5 = SnippetGenerator.generate(text, "Python", max_snippets=5)
        assert len(snippets_5) == 5
        snippets_5 = SearchController._generate_snippets(text, "Python", max_snippets=5)
        assert len(snippets_5) == 5  # Should get exactly 5 with enough separation

    def test_snippet_length(self):
        """Test snippet length is reasonable."""
        text = "word " * 200
        snippets = SnippetGenerator.generate(text, "word")
        text = "word " * 200  # Long text
        snippets = SearchController._generate_snippets(text, "word")

        for snippet in snippets:
            assert len(snippet) <= 200
            # Default max_length is 150 + some context
            assert len(snippet) <= 200  # Some buffer for ellipsis


class TestFullPipeline:
@@ -149,6 +157,7 @@ class TestFullPipeline:

    def test_webvtt_to_snippets_integration(self):
        """Test full pipeline from WebVTT to search snippets."""
        # Create WebVTT with well-separated content for multiple snippets
        webvtt = (
            """WEBVTT

@@ -173,362 +182,17 @@ class TestFullPipeline:
"""
        )

        plain_text = WebVTTProcessor.extract_text(webvtt)
        snippets = SnippetGenerator.generate(plain_text, "machine learning")
        # Extract and generate snippets
        plain_text = SearchController._extract_webvtt_text(webvtt)
        snippets = SearchController._generate_snippets(plain_text, "machine learning")

        assert len(snippets) >= 1
        assert len(snippets) <= 3
        # Should find at least 2 snippets (text might still be close together)
        assert len(snippets) >= 1  # At minimum one snippet containing matches
        assert len(snippets) <= 3  # At most 3 by default

        # No WebVTT artifacts in snippets
        for snippet in snippets:
            assert "machine learning" in snippet.lower()
            assert "<v Speaker" not in snippet
            assert "00:00" not in snippet
            assert "-->" not in snippet


class TestMultiWordQueryBehavior:
    """Tests for multi-word query behavior and exact phrase matching."""

    def test_multi_word_query_snippet_behavior(self):
        """Test that multi-word queries generate snippets based on exact phrase matching."""
        sample_text = """This is a sample transcript where user Alice is talking.
Later in the conversation, jordan mentions something important.
The user jordan collaboration was successful.
Another user named Bob joins the discussion."""

        user_snippets = SnippetGenerator.generate(sample_text, "user")
        assert len(user_snippets) == 2, "Should find 2 snippets for 'user'"

        jordan_snippets = SnippetGenerator.generate(sample_text, "jordan")
        assert len(jordan_snippets) >= 1, "Should find at least 1 snippet for 'jordan'"

        multi_word_snippets = SnippetGenerator.generate(sample_text, "user jordan")
        assert len(multi_word_snippets) == 1, (
            "Should return exactly 1 snippet for 'user jordan' "
            "(only the exact phrase match, not individual word occurrences)"
        )

        snippet = multi_word_snippets[0]
        assert (
            "user jordan" in snippet.lower()
        ), "The snippet should contain the exact phrase 'user jordan'"

        assert (
            "alice" not in snippet.lower()
        ), "The snippet should not include the first standalone 'user' with Alice"

    def test_multi_word_query_without_exact_match(self):
        """Test snippet generation when exact phrase is not found."""
        sample_text = """User Alice is here. Bob and jordan are talking.
Later jordan mentions something. The user is happy."""

        snippets = SnippetGenerator.generate(sample_text, "user jordan")

        assert (
            len(snippets) >= 1
        ), "Should find at least 1 snippet when falling back to first word"

        all_snippets_text = " ".join(snippets).lower()
        assert (
            "user" in all_snippets_text
        ), "Snippets should contain 'user' (the first word)"

    def test_exact_phrase_at_text_boundaries(self):
        """Test snippet generation when exact phrase appears at text boundaries."""

        text_start = "user jordan started the meeting. Other content here."
        snippets = SnippetGenerator.generate(text_start, "user jordan")
        assert len(snippets) == 1
        assert "user jordan" in snippets[0].lower()

        text_end = "Other content here. The meeting ended with user jordan"
        snippets = SnippetGenerator.generate(text_end, "user jordan")
        assert len(snippets) == 1
        assert "user jordan" in snippets[0].lower()

    def test_multi_word_query_matches_words_appearing_separately_and_together(self):
        """Test that multi-word queries prioritize exact phrase matches over individual word occurrences."""
        sample_text = """This is a sample transcript where user Alice is talking.
Later in the conversation, jordan mentions something important.
The user jordan collaboration was successful.
Another user named Bob joins the discussion."""

        search_query = "user jordan"
        snippets = SnippetGenerator.generate(sample_text, search_query)

        assert len(snippets) == 1, (
            f"Expected exactly 1 snippet for '{search_query}' when exact phrase exists, "
            f"got {len(snippets)}. Should ignore individual word occurrences."
        )

        snippet = snippets[0]

        assert (
            search_query in snippet.lower()
        ), f"Snippet should contain the exact phrase '{search_query}'. Got: {snippet}"

        assert (
            "jordan mentions" in snippet.lower()
        ), f"Snippet should include context before the exact phrase match. Got: {snippet}"

        assert (
            "alice" not in snippet.lower()
        ), f"Snippet should not include separate occurrences of individual words. Got: {snippet}"

        text_2 = """The alpha version was released.
Beta testing started yesterday.
The alpha beta integration is complete."""

        snippets_2 = SnippetGenerator.generate(text_2, "alpha beta")
        assert len(snippets_2) == 1, "Should return 1 snippet for exact phrase match"
        assert "alpha beta" in snippets_2[0].lower(), "Should contain exact phrase"
        assert (
            "version" not in snippets_2[0].lower()
        ), "Should not include first separate occurrence"


class TestSnippetGenerationEnhanced:
    """Additional snippet generation tests from test_search_enhancements.py."""

    def test_snippet_generation_from_webvtt(self):
        """Test snippet generation from WebVTT content."""
        webvtt_content = """WEBVTT

00:00:00.000 --> 00:00:05.000
This is the beginning of the transcript

00:00:05.000 --> 00:00:10.000
The search term appears here in the middle

00:00:10.000 --> 00:00:15.000
And this is the end of the content"""

        plain_text = WebVTTProcessor.extract_text(webvtt_content)
        snippets = SnippetGenerator.generate(plain_text, "search term")

        assert len(snippets) > 0
        assert any("search term" in snippet.lower() for snippet in snippets)

    def test_extract_webvtt_text_with_malformed_variations(self):
        """Test WebVTT extraction with various malformed content."""
        malformed_vtt = "This is not valid WebVTT content"
        result = WebVTTProcessor.extract_text(malformed_vtt)
        assert result == ""

        partial_vtt = "WEBVTT\nNo timestamps here"
        result = WebVTTProcessor.extract_text(partial_vtt)
        assert result == "" or "No timestamps" not in result


class TestPureFunctions:
    """Test the pure functions extracted for functional programming."""

    def test_find_all_matches(self):
        """Test finding all match positions in text."""
        text = "Python is great. Python is powerful. I love Python."
        matches = list(SnippetGenerator.find_all_matches(text, "Python"))
        assert matches == [0, 17, 44]

        matches = list(SnippetGenerator.find_all_matches(text, "python"))
        assert matches == [0, 17, 44]

        matches = list(SnippetGenerator.find_all_matches(text, "Ruby"))
        assert matches == []

        matches = list(SnippetGenerator.find_all_matches("", "test"))
        assert matches == []
        matches = list(SnippetGenerator.find_all_matches("test", ""))
        assert matches == []

    def test_create_snippet(self):
        """Test creating a snippet from a match position."""
        text = "This is a long text with the word Python in the middle and more text after."

        snippet = SnippetGenerator.create_snippet(text, 35, max_length=150)
        assert "Python" in snippet.text()
        assert snippet.start >= 0
        assert snippet.end <= len(text)
        assert isinstance(snippet, SnippetCandidate)

        assert len(snippet.text()) > 0
        assert snippet.start <= snippet.end

        long_text = "A" * 200
        snippet = SnippetGenerator.create_snippet(long_text, 100, max_length=50)
        assert snippet.text().startswith("...")
        assert snippet.text().endswith("...")

        snippet = SnippetGenerator.create_snippet("short text", 0, max_length=100)
        assert snippet.start == 0
        assert "short text" in snippet.text()

    def test_filter_non_overlapping(self):
        """Test filtering overlapping snippets."""
        candidates = [
            SnippetCandidate(_text="First snippet", start=0, _original_text_length=100),
            SnippetCandidate(_text="Overlapping", start=10, _original_text_length=100),
            SnippetCandidate(
                _text="Third snippet", start=40, _original_text_length=100
            ),
            SnippetCandidate(
                _text="Fourth snippet", start=65, _original_text_length=100
            ),
        ]

        filtered = list(SnippetGenerator.filter_non_overlapping(iter(candidates)))
        assert filtered == [
            "First snippet...",
            "...Third snippet...",
            "...Fourth snippet...",
        ]

        filtered = list(SnippetGenerator.filter_non_overlapping(iter([])))
        assert filtered == []

    def test_generate_integration(self):
        """Test the main SnippetGenerator.generate function."""
        text = "Machine learning is amazing. Machine learning transforms data. Learn machine learning today."

        snippets = SnippetGenerator.generate(text, "machine learning")
        assert len(snippets) <= 3
        assert all("machine learning" in s.lower() for s in snippets)

        snippets = SnippetGenerator.generate(text, "machine learning", max_snippets=2)
        assert len(snippets) <= 2

        snippets = SnippetGenerator.generate(text, "machine vision")
        assert len(snippets) > 0
        assert any("machine" in s.lower() for s in snippets)

    def test_extract_webvtt_text_basic(self):
        """Test WebVTT text extraction (basic test, full tests exist elsewhere)."""
        webvtt = """WEBVTT

00:00:00.000 --> 00:00:02.000
Hello world

00:00:02.000 --> 00:00:04.000
This is a test"""

        result = WebVTTProcessor.extract_text(webvtt)
        assert "Hello world" in result
        assert "This is a test" in result

        # Test empty input
        assert WebVTTProcessor.extract_text("") == ""
        assert WebVTTProcessor.extract_text(None) == ""

    def test_generate_webvtt_snippets(self):
        """Test generating snippets from WebVTT content."""
        webvtt = """WEBVTT

00:00:00.000 --> 00:00:02.000
Python programming is great

00:00:02.000 --> 00:00:04.000
Learn Python today"""

        snippets = WebVTTProcessor.generate_snippets(webvtt, "Python")
        assert len(snippets) > 0
        assert any("Python" in s for s in snippets)

        snippets = WebVTTProcessor.generate_snippets("", "Python")
        assert snippets == []

    def test_from_summary(self):
        """Test generating snippets from summary text."""
        summary = "This meeting discussed Python development and machine learning applications."

        snippets = SnippetGenerator.from_summary(summary, "Python")
        assert len(snippets) > 0
        assert any("Python" in s for s in snippets)

        long_summary = "Python " * 20
        snippets = SnippetGenerator.from_summary(long_summary, "Python")
        assert len(snippets) <= 2

    def test_combine_sources(self):
        """Test combining snippets from multiple sources."""
        summary = "Python is a great programming language."
        webvtt = """WEBVTT

00:00:00.000 --> 00:00:02.000
Learn Python programming

00:00:02.000 --> 00:00:04.000
Python is powerful"""

        snippets, total_count = SnippetGenerator.combine_sources(
            summary, webvtt, "Python", max_total=3
        )
        assert len(snippets) <= 3
        assert len(snippets) > 0
        assert total_count > 0

        snippets, total_count = SnippetGenerator.combine_sources(
            summary, None, "Python", max_total=3
        )
        assert len(snippets) > 0
        assert all("Python" in s for s in snippets)
        assert total_count == 1

        snippets, total_count = SnippetGenerator.combine_sources(
            None, webvtt, "Python", max_total=3
        )
        assert len(snippets) > 0
        assert total_count == 2

        long_summary = "Python " * 10
        snippets, total_count = SnippetGenerator.combine_sources(
            long_summary, webvtt, "Python", max_total=2
        )
        assert len(snippets) == 2
        assert total_count >= 10

    def test_match_counting_sum_logic(self):
        """Test that match counting correctly sums matches from both sources."""
        summary = "data science uses data analysis and data mining techniques"
        webvtt = """WEBVTT

00:00:00.000 --> 00:00:02.000
Big data processing

00:00:02.000 --> 00:00:04.000
data visualization and data storage"""

        snippets, total_count = SnippetGenerator.combine_sources(
            summary, webvtt, "data", max_total=3
        )
        assert total_count == 6
        assert len(snippets) <= 3

        summary_snippets, summary_count = SnippetGenerator.combine_sources(
            summary, None, "data", max_total=3
        )
        assert summary_count == 3

        webvtt_snippets, webvtt_count = SnippetGenerator.combine_sources(
            None, webvtt, "data", max_total=3
        )
        assert webvtt_count == 3

        snippets_empty, count_empty = SnippetGenerator.combine_sources(
            None, None, "data", max_total=3
        )
        assert snippets_empty == []
        assert count_empty == 0

    def test_edge_cases(self):
        """Test edge cases for the pure functions."""
        text = "Test with special: @#$%^&*() characters"
        snippets = SnippetGenerator.generate(text, "@#$%")
        assert len(snippets) > 0

        long_query = "a" * 100
        snippets = SnippetGenerator.generate("Some text", long_query)
        assert snippets == []

        text = "Unicode test: café, naïve, 日本語"
        snippets = SnippetGenerator.generate(text, "café")
        assert len(snippets) > 0
        assert "café" in snippets[0]
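For orientation, the offsets asserted in test_find_all_matches come from case-insensitive substring positions. A small stand-in that satisfies the same assertions (not the reflector implementation, which lives in reflector.db.search):

import re

def find_all_matches(text, term):
    # Yield the start offset of each case-insensitive occurrence of term.
    if not text or not term:
        return
    for m in re.finditer(re.escape(term), text, flags=re.IGNORECASE):
        yield m.start()

assert list(
    find_all_matches("Python is great. Python is powerful. I love Python.", "python")
) == [0, 17, 44]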
872
server/uv.lock
generated
File diff suppressed because it is too large
@@ -1,67 +1,26 @@
import React, { useEffect } from "react";
import React from "react";
import { Pagination, IconButton, ButtonGroup } from "@chakra-ui/react";
import { LuChevronLeft, LuChevronRight } from "react-icons/lu";

// explicitly 1-based to prevent +/-1-confusion errors
export const FIRST_PAGE = 1 as PaginationPage;
export const parsePaginationPage = (
  page: number,
):
  | {
      value: PaginationPage;
    }
  | {
      error: string;
    } => {
  if (page < FIRST_PAGE)
    return {
      error: "Page must be greater than 0",
    };
  if (!Number.isInteger(page))
    return {
      error: "Page must be an integer",
    };
  return {
    value: page as PaginationPage,
  };
};
export type PaginationPage = number & { __brand: "PaginationPage" };
export const PaginationPage = (page: number): PaginationPage => {
  const v = parsePaginationPage(page);
  if ("error" in v) throw new Error(v.error);
  return v.value;
};

export const paginationPageTo0Based = (page: PaginationPage): number =>
  page - FIRST_PAGE;

type PaginationProps = {
  page: PaginationPage;
  setPage: (page: PaginationPage) => void;
  page: number;
  setPage: (page: number) => void;
  total: number;
  size: number;
};

export const totalPages = (total: number, size: number) => {
  return Math.ceil(total / size);
};

export default function PaginationComponent(props: PaginationProps) {
  const { page, setPage, total, size } = props;
  useEffect(() => {
    if (page > totalPages(total, size)) {
      console.error(
        `Page number (${page}) is greater than total pages (${totalPages}) in pagination`,
      );
    }
  }, [page, totalPages(total, size)]);
  const totalPages = Math.ceil(total / size);

  if (totalPages <= 1) return null;

  return (
    <Pagination.Root
      count={total}
      pageSize={size}
      page={page}
      onPageChange={(details) => setPage(PaginationPage(details.page))}
      onPageChange={(details) => setPage(details.page)}
      style={{ display: "flex", justifyContent: "center" }}
    >
      <ButtonGroup variant="ghost" size="xs">
34
www/app/(app)/browse/_components/SearchBar.tsx
Normal file
@@ -0,0 +1,34 @@
import React, { useState } from "react";
import { Flex, Input, Button } from "@chakra-ui/react";

interface SearchBarProps {
  onSearch: (searchTerm: string) => void;
}

export default function SearchBar({ onSearch }: SearchBarProps) {
  const [searchInputValue, setSearchInputValue] = useState("");

  const handleSearch = () => {
    onSearch(searchInputValue);
  };

  const handleKeyDown = (event: React.KeyboardEvent) => {
    if (event.key === "Enter") {
      handleSearch();
    }
  };

  return (
    <Flex alignItems="center">
      <Input
        placeholder="Search transcriptions..."
        value={searchInputValue}
        onChange={(e) => setSearchInputValue(e.target.value)}
        onKeyDown={handleKeyDown}
      />
      <Button ml={2} onClick={handleSearch}>
        Search
      </Button>
    </Flex>
  );
}
@@ -4,8 +4,8 @@ import { LuMenu, LuTrash, LuRotateCw } from "react-icons/lu";

interface TranscriptActionsMenuProps {
  transcriptId: string;
  onDelete: (transcriptId: string) => void;
  onReprocess: (transcriptId: string) => void;
  onDelete: (transcriptId: string) => (e: any) => void;
  onReprocess: (transcriptId: string) => (e: any) => void;
}

export default function TranscriptActionsMenu({
@@ -24,17 +24,11 @@ export default function TranscriptActionsMenu({
        <Menu.Content>
          <Menu.Item
            value="reprocess"
            onClick={() => onReprocess(transcriptId)}
            onClick={(e) => onReprocess(transcriptId)(e)}
          >
            <LuRotateCw /> Reprocess
          </Menu.Item>
          <Menu.Item
            value="delete"
            onClick={(e) => {
              e.stopPropagation();
              onDelete(transcriptId);
            }}
          >
          <Menu.Item value="delete" onClick={(e) => onDelete(transcriptId)(e)}>
            <LuTrash /> Delete
          </Menu.Item>
        </Menu.Content>
@@ -1,290 +1,27 @@
import React, { useState } from "react";
import {
Box,
Stack,
Text,
Flex,
Link,
Spinner,
Badge,
HStack,
VStack,
} from "@chakra-ui/react";
import React from "react";
import { Box, Stack, Text, Flex, Link, Spinner } from "@chakra-ui/react";
import NextLink from "next/link";
import { GetTranscriptMinimal } from "../../../api";
import { formatTimeMs, formatLocalDate } from "../../../lib/time";
import TranscriptStatusIcon from "./TranscriptStatusIcon";
import TranscriptActionsMenu from "./TranscriptActionsMenu";
import {
highlightMatches,
generateTextFragment,
} from "../../../lib/textHighlight";
import { SearchResult } from "../../../api";

interface TranscriptCardsProps {
results: SearchResult[];
query: string;
isLoading?: boolean;
onDelete: (transcriptId: string) => void;
onReprocess: (transcriptId: string) => void;
}

function highlightText(text: string, query: string): React.ReactNode {
if (!query) return text;

const matches = highlightMatches(text, query);

if (matches.length === 0) return text;

// Sort matches by index to process them in order
const sortedMatches = [...matches].sort((a, b) => a.index - b.index);

const parts: React.ReactNode[] = [];
let lastIndex = 0;

sortedMatches.forEach((match, i) => {
// Add text before the match
if (match.index > lastIndex) {
parts.push(
<Text as="span" key={`text-${i}`} display="inline">
{text.slice(lastIndex, match.index)}
</Text>,
);
}

// Add the highlighted match
parts.push(
<Text
as="mark"
key={`match-${i}`}
bg="yellow.200"
px={0.5}
display="inline"
>
{match.match}
</Text>,
);

lastIndex = match.index + match.match.length;
});

// Add remaining text after last match
if (lastIndex < text.length) {
parts.push(
<Text as="span" key={`text-end`} display="inline">
{text.slice(lastIndex)}
</Text>,
);
}

return parts;
}

const transcriptHref = (
transcriptId: string,
mainSnippet: string,
query: string,
): `/transcripts/${string}` => {
const urlTextFragment = mainSnippet
? generateTextFragment(mainSnippet, query)
: null;
const urlTextFragmentWithHash = urlTextFragment
? `#${urlTextFragment.k}=${encodeURIComponent(urlTextFragment.v)}`
: "";
return `/transcripts/${transcriptId}${urlTextFragmentWithHash}`;
};

// note that it's strongly tied to search logic - in case you want to use it independently, refactor
function TranscriptCard({
result,
query,
onDelete,
onReprocess,
}: {
result: SearchResult;
query: string;
onDelete: (transcriptId: string) => void;
onReprocess: (transcriptId: string) => void;
}) {
const [isExpanded, setIsExpanded] = useState(false);

const mainSnippet = result.search_snippets[0];
const additionalSnippets = result.search_snippets.slice(1);
const totalMatches = result.total_match_count || 0;
const snippetsShown = result.search_snippets.length;
const remainingMatches = totalMatches - snippetsShown;
const hasAdditionalSnippets = additionalSnippets.length > 0;
const resultTitle = result.title || "Unnamed Transcript";

const formattedDuration = result.duration
? formatTimeMs(result.duration)
: "N/A";
const formattedDate = formatLocalDate(result.created_at);
const source =
result.source_kind === "room"
? result.room_name || result.room_id
: result.source_kind;

const handleExpandClick = (e: React.MouseEvent) => {
e.preventDefault();
e.stopPropagation();
setIsExpanded(!isExpanded);
};

return (
<Box borderWidth={1} p={4} borderRadius="md" fontSize="sm">
<Flex justify="space-between" alignItems="flex-start" gap="2">
<Box>
<TranscriptStatusIcon status={result.status} />
</Box>
<Box flex="1">
{/* Title with highlighting and text fragment for deep linking */}
<Link
as={NextLink}
href={transcriptHref(result.id, mainSnippet, query)}
fontWeight="600"
display="block"
mb={2}
>
{highlightText(resultTitle, query)}
</Link>

{/* Metadata - Horizontal on desktop, vertical on mobile */}
<Flex
direction={{ base: "column", md: "row" }}
gap={{ base: 1, md: 2 }}
fontSize="xs"
color="gray.600"
flexWrap="wrap"
align={{ base: "flex-start", md: "center" }}
>
<Flex align="center" gap={1}>
<Text fontWeight="medium" color="gray.500">
Source:
</Text>
<Text>{source}</Text>
</Flex>
<Text display={{ base: "none", md: "block" }} color="gray.400">
•
</Text>
<Flex align="center" gap={1}>
<Text fontWeight="medium" color="gray.500">
Date:
</Text>
<Text>{formattedDate}</Text>
</Flex>
<Text display={{ base: "none", md: "block" }} color="gray.400">
•
</Text>
<Flex align="center" gap={1}>
<Text fontWeight="medium" color="gray.500">
Duration:
</Text>
<Text>{formattedDuration}</Text>
</Flex>
</Flex>

{/* Search Results Section - only show when searching */}
{mainSnippet && (
<>
{/* Main Snippet */}
<Box
mt={3}
p={2}
bg="gray.50"
borderLeft="2px solid"
borderLeftColor="blue.400"
borderRadius="sm"
fontSize="xs"
>
<Text color="gray.700">
{highlightText(mainSnippet, query)}
</Text>
</Box>

{hasAdditionalSnippets && (
<>
<Flex
mt={2}
p={2}
bg="blue.50"
borderRadius="sm"
cursor="pointer"
onClick={handleExpandClick}
_hover={{ bg: "blue.100" }}
align="center"
justify="space-between"
>
<HStack gap={2}>
<Badge
bg="blue.500"
color="white"
fontSize="xs"
px={2}
borderRadius="full"
>
{remainingMatches > 0
? `${additionalSnippets.length + remainingMatches}+`
: additionalSnippets.length}
</Badge>
<Text fontSize="xs" color="blue.600" fontWeight="medium">
more{" "}
{additionalSnippets.length + remainingMatches === 1
? "match"
: "matches"}
{remainingMatches > 0 &&
` (${additionalSnippets.length} shown)`}
</Text>
</HStack>
<Text fontSize="xs" color="blue.600">
{isExpanded ? "▲" : "▼"}
</Text>
</Flex>

{/* Additional Snippets */}
{isExpanded && (
<VStack align="stretch" gap={2} mt={2}>
{additionalSnippets.map((snippet, index) => (
<Box
key={index}
p={2}
bg="gray.50"
borderLeft="2px solid"
borderLeftColor="gray.300"
borderRadius="sm"
fontSize="xs"
>
<Text color="gray.700">
{highlightText(snippet, query)}
</Text>
</Box>
))}
</VStack>
)}
</>
)}
</>
)}
</Box>
<TranscriptActionsMenu
transcriptId={result.id}
onDelete={onDelete}
onReprocess={onReprocess}
/>
</Flex>
</Box>
);
transcripts: GetTranscriptMinimal[];
onDelete: (transcriptId: string) => (e: any) => void;
onReprocess: (transcriptId: string) => (e: any) => void;
loading?: boolean;
}

export default function TranscriptCards({
results,
query,
isLoading,
transcripts,
onDelete,
onReprocess,
loading,
}: TranscriptCardsProps) {
return (
<Box position="relative">
{isLoading && (
<Box display={{ base: "block", lg: "none" }} position="relative">
{loading && (
<Flex
position="absolute"
top={0}
@@ -300,19 +37,48 @@ export default function TranscriptCards({
</Flex>
)}
<Box
opacity={isLoading ? 0.9 : 1}
pointerEvents={isLoading ? "none" : "auto"}
opacity={loading ? 0.9 : 1}
pointerEvents={loading ? "none" : "auto"}
transition="opacity 0.2s ease-in-out"
>
<Stack gap={3}>
{results.map((result) => (
<TranscriptCard
key={result.id}
result={result}
query={query}
onDelete={onDelete}
onReprocess={onReprocess}
/>
<Stack gap={2}>
{transcripts.map((item) => (
<Box
key={item.id}
borderWidth={1}
p={4}
borderRadius="md"
fontSize="sm"
>
<Flex justify="space-between" alignItems="flex-start" gap="2">
<Box>
<TranscriptStatusIcon status={item.status} />
</Box>
<Box flex="1">
<Link
as={NextLink}
href={`/transcripts/${item.id}`}
fontWeight="600"
display="block"
>
{item.title || "Unnamed Transcript"}
</Link>
<Text>
Source:{" "}
{item.source_kind === "room"
? item.room_name
: item.source_kind}
</Text>
<Text>Date: {formatLocalDate(item.created_at)}</Text>
<Text>Duration: {formatTimeMs(item.duration)}</Text>
</Box>
<TranscriptActionsMenu
transcriptId={item.id}
onDelete={onDelete}
onReprocess={onReprocess}
/>
</Flex>
</Box>
))}
</Stack>
</Box>

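The snippet-splitting logic in highlightText above, reduced to plain strings so it runs standalone; the match finder is stubbed here and the <mark> string stands in for the <Text as="mark"> node:

interface Match { index: number; match: string }

function splitByMatches(text: string, matches: Match[]): string[] {
  const sorted = [...matches].sort((a, b) => a.index - b.index);
  const parts: string[] = [];
  let lastIndex = 0;
  for (const m of sorted) {
    if (m.index > lastIndex) parts.push(text.slice(lastIndex, m.index)); // text before the match
    parts.push(`<mark>${m.match}</mark>`); // the highlighted match
    lastIndex = m.index + m.match.length;
  }
  if (lastIndex < text.length) parts.push(text.slice(lastIndex)); // remaining tail
  return parts;
}

console.log(splitByMatches("weekly sync notes", [{ index: 7, match: "sync" }]));
// ["weekly ", "<mark>sync</mark>", " notes"]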
99
www/app/(app)/browse/_components/TranscriptTable.tsx
Normal file
@@ -0,0 +1,99 @@
import React from "react";
import { Box, Table, Link, Flex, Spinner } from "@chakra-ui/react";
import NextLink from "next/link";
import { GetTranscriptMinimal } from "../../../api";
import { formatTimeMs, formatLocalDate } from "../../../lib/time";
import TranscriptStatusIcon from "./TranscriptStatusIcon";
import TranscriptActionsMenu from "./TranscriptActionsMenu";

interface TranscriptTableProps {
transcripts: GetTranscriptMinimal[];
onDelete: (transcriptId: string) => (e: any) => void;
onReprocess: (transcriptId: string) => (e: any) => void;
loading?: boolean;
}

export default function TranscriptTable({
transcripts,
onDelete,
onReprocess,
loading,
}: TranscriptTableProps) {
return (
<Box display={{ base: "none", lg: "block" }} position="relative">
{loading && (
<Flex
position="absolute"
top={0}
left={0}
right={0}
bottom={0}
align="center"
justify="center"
>
<Spinner size="xl" color="gray.700" />
</Flex>
)}
<Box
opacity={loading ? 0.9 : 1}
pointerEvents={loading ? "none" : "auto"}
transition="opacity 0.2s ease-in-out"
>
<Table.Root>
<Table.Header>
<Table.Row>
<Table.ColumnHeader
width="16px"
fontWeight="600"
></Table.ColumnHeader>
<Table.ColumnHeader width="400px" fontWeight="600">
Transcription Title
</Table.ColumnHeader>
<Table.ColumnHeader width="150px" fontWeight="600">
Source
</Table.ColumnHeader>
<Table.ColumnHeader width="200px" fontWeight="600">
Date
</Table.ColumnHeader>
<Table.ColumnHeader width="100px" fontWeight="600">
Duration
</Table.ColumnHeader>
<Table.ColumnHeader
width="50px"
fontWeight="600"
></Table.ColumnHeader>
</Table.Row>
</Table.Header>
<Table.Body>
{transcripts.map((item) => (
<Table.Row key={item.id}>
<Table.Cell>
<TranscriptStatusIcon status={item.status} />
</Table.Cell>
<Table.Cell>
<Link as={NextLink} href={`/transcripts/${item.id}`}>
{item.title || "Unnamed Transcript"}
</Link>
</Table.Cell>
<Table.Cell>
{item.source_kind === "room"
? item.room_name
: item.source_kind}
</Table.Cell>
<Table.Cell>{formatLocalDate(item.created_at)}</Table.Cell>
<Table.Cell>{formatTimeMs(item.duration)}</Table.Cell>
<Table.Cell>
<TranscriptActionsMenu
transcriptId={item.id}
onDelete={onDelete}
onReprocess={onReprocess}
/>
</Table.Cell>
</Table.Row>
))}
</Table.Body>
</Table.Root>
</Box>
</Box>
);
}
@@ -1,264 +1,33 @@
"use client";
import React, { useState, useEffect } from "react";
import {
Flex,
Spinner,
Heading,
Text,
Link,
Box,
Stack,
Input,
Button,
IconButton,
} from "@chakra-ui/react";
import {
useQueryState,
parseAsString,
parseAsInteger,
parseAsStringLiteral,
} from "nuqs";
import { LuX } from "react-icons/lu";
import { useSearchTranscripts } from "../transcripts/useSearchTranscripts";
import { Flex, Spinner, Heading, Text, Link } from "@chakra-ui/react";
import useTranscriptList from "../transcripts/useTranscriptList";
import useSessionUser from "../../lib/useSessionUser";
import { Room, SourceKind, SearchResult, $SourceKind } from "../../api";
import { Room } from "../../api";
import Pagination from "./_components/Pagination";
import useApi from "../../lib/useApi";
import { useError } from "../../(errors)/errorContext";
import { SourceKind } from "../../api";
import FilterSidebar from "./_components/FilterSidebar";
import Pagination, {
FIRST_PAGE,
PaginationPage,
parsePaginationPage,
totalPages as getTotalPages,
} from "./_components/Pagination";
import SearchBar from "./_components/SearchBar";
import TranscriptTable from "./_components/TranscriptTable";
import TranscriptCards from "./_components/TranscriptCards";
import DeleteTranscriptDialog from "./_components/DeleteTranscriptDialog";
import { formatLocalDate } from "../../lib/time";
import { RECORD_A_MEETING_URL } from "../../api/urls";

const SEARCH_FORM_QUERY_INPUT_NAME = "query" as const;

const usePrefetchRooms = (setRooms: (rooms: Room[]) => void): void => {
const { setError } = useError();
const api = useApi();
useEffect(() => {
if (!api) return;
api
.v1RoomsList({ page: 1 })
.then((rooms) => setRooms(rooms.items))
.catch((err) => setError(err, "There was an error fetching the rooms"));
}, [api, setError]);
};

const SearchForm: React.FC<{
setPage: (page: PaginationPage) => void;
sourceKind: SourceKind | null;
roomId: string | null;
setSourceKind: (sourceKind: SourceKind | null) => void;
setRoomId: (roomId: string | null) => void;
rooms: Room[];
searchQuery: string | null;
setSearchQuery: (query: string | null) => void;
}> = ({
setPage,
sourceKind,
roomId,
setRoomId,
setSourceKind,
rooms,
searchQuery,
setSearchQuery,
}) => {
// to keep the search input controllable + more fine grained control (urlSearchQuery is updated on submits)
const [searchInputValue, setSearchInputValue] = useState(searchQuery || "");
const handleSearchQuerySubmit = async (d: FormData) => {
await setSearchQuery((d.get(SEARCH_FORM_QUERY_INPUT_NAME) as string) || "");
};

const handleClearSearch = () => {
setSearchInputValue("");
setSearchQuery(null);
setPage(FIRST_PAGE);
};
return (
<Stack gap={2}>
<form action={handleSearchQuerySubmit}>
<Flex alignItems="center">
<Box position="relative" flex="1">
<Input
placeholder="Search transcriptions..."
value={searchInputValue}
onChange={(e) => setSearchInputValue(e.target.value)}
name={SEARCH_FORM_QUERY_INPUT_NAME}
pr={searchQuery ? "2.5rem" : undefined}
/>
{searchQuery && (
<IconButton
aria-label="Clear search"
size="sm"
variant="ghost"
onClick={handleClearSearch}
position="absolute"
right="0.25rem"
top="50%"
transform="translateY(-50%)"
_hover={{ bg: "gray.100" }}
>
<LuX />
</IconButton>
)}
</Box>
<Button ml={2} type="submit">
Search
</Button>
</Flex>
</form>
<UnderSearchFormFilterIndicators
sourceKind={sourceKind}
roomId={roomId}
setSourceKind={setSourceKind}
setRoomId={setRoomId}
rooms={rooms}
/>
</Stack>
);
};

const UnderSearchFormFilterIndicators: React.FC<{
sourceKind: SourceKind | null;
roomId: string | null;
setSourceKind: (sourceKind: SourceKind | null) => void;
setRoomId: (roomId: string | null) => void;
rooms: Room[];
}> = ({ sourceKind, roomId, setRoomId, setSourceKind, rooms }) => {
return (
<>
{(sourceKind || roomId) && (
<Flex gap={2} flexWrap="wrap" align="center">
<Text fontSize="sm" color="gray.600">
Active filters:
</Text>
{sourceKind && (
<Flex
align="center"
px={2}
py={1}
bg="blue.100"
borderRadius="md"
fontSize="xs"
gap={1}
>
<Text>
{roomId
? `Room: ${
rooms.find((r) => r.id === roomId)?.name || roomId
}`
: `Source: ${sourceKind}`}
</Text>
<Button
size="xs"
variant="ghost"
minW="auto"
h="auto"
p="1px"
onClick={() => {
setSourceKind(null);
// TODO questionable
setRoomId(null);
}}
_hover={{ bg: "blue.200" }}
aria-label="Clear filter"
>
<LuX size={14} />
</Button>
</Flex>
)}
</Flex>
)}
</>
);
};

const EmptyResult: React.FC<{
searchQuery: string;
}> = ({ searchQuery }) => {
return (
<Flex flexDir="column" alignItems="center" justifyContent="center" py={8}>
<Text textAlign="center">
{searchQuery
? `No results found for "${searchQuery}". Try adjusting your search terms.`
: "No transcripts found, but you can "}
{!searchQuery && (
<>
<Link href={RECORD_A_MEETING_URL} color="blue.500">
record a meeting
</Link>
{" to get started."}
</>
)}
</Text>
</Flex>
);
};

export default function TranscriptBrowser() {
const [urlSearchQuery, setUrlSearchQuery] = useQueryState(
"q",
parseAsString.withDefault("").withOptions({ shallow: false }),
);

const [urlSourceKind, setUrlSourceKind] = useQueryState(
"source",
parseAsStringLiteral($SourceKind.enum).withOptions({
shallow: false,
}),
);
const [urlRoomId, setUrlRoomId] = useQueryState(
"room",
parseAsString.withDefault("").withOptions({ shallow: false }),
);

const [urlPage, setPage] = useQueryState(
"page",
parseAsInteger.withDefault(1).withOptions({ shallow: false }),
);

const [page, _setSafePage] = useState(FIRST_PAGE);

// safety net
useEffect(() => {
const maybePage = parsePaginationPage(urlPage);
if ("error" in maybePage) {
setPage(FIRST_PAGE).then(() => {
/*may be called n times we dont care*/
});
return;
}
_setSafePage(maybePage.value);
}, [urlPage]);

const [selectedSourceKind, setSelectedSourceKind] =
useState<SourceKind | null>(null);
const [selectedRoomId, setSelectedRoomId] = useState("");
const [rooms, setRooms] = useState<Room[]>([]);

const pageSize = 20;
const {
results,
totalCount: totalResults,
isLoading,
reload,
} = useSearchTranscripts(
urlSearchQuery,
{
roomIds: urlRoomId ? [urlRoomId] : null,
sourceKind: urlSourceKind,
},
{
pageSize,
page,
},
const [page, setPage] = useState(1);
const [searchTerm, setSearchTerm] = useState("");
const { loading, response, refetch } = useTranscriptList(
page,
selectedSourceKind,
selectedRoomId,
searchTerm,
);

const totalPages = getTotalPages(totalResults, pageSize);

const userName = useSessionUser().name;
const [deletionLoading, setDeletionLoading] = useState(false);
const api = useApi();
@@ -266,73 +35,37 @@ export default function TranscriptBrowser() {
const cancelRef = React.useRef(null);
const [transcriptToDeleteId, setTranscriptToDeleteId] =
React.useState<string>();
const [deletedItemIds, setDeletedItemIds] = React.useState<string[]>();

usePrefetchRooms(setRooms);
useEffect(() => {
setDeletedItemIds([]);
}, [page, response]);

useEffect(() => {
if (!api) return;
api
.v1RoomsList({ page: 1 })
.then((rooms) => setRooms(rooms.items))
.catch((err) => setError(err, "There was an error fetching the rooms"));
}, [api]);

const handleFilterTranscripts = (
sourceKind: SourceKind | null,
roomId: string,
) => {
setUrlSourceKind(sourceKind);
setUrlRoomId(roomId);
setSelectedSourceKind(sourceKind);
setSelectedRoomId(roomId);
setPage(1);
};

const onCloseDeletion = () => setTranscriptToDeleteId(undefined);

const confirmDeleteTranscript = (transcriptId: string) => {
if (!api || deletionLoading) return;
setDeletionLoading(true);
api
.v1TranscriptDelete({ transcriptId })
.then(() => {
setDeletionLoading(false);
onCloseDeletion();
reload();
})
.catch((err) => {
setDeletionLoading(false);
setError(err, "There was an error deleting the transcript");
});
const handleSearch = (searchTerm: string) => {
setPage(1);
setSearchTerm(searchTerm);
setSelectedSourceKind(null);
setSelectedRoomId("");
};

const handleProcessTranscript = (transcriptId: string) => {
if (!api) {
console.error("API not available on handleProcessTranscript");
return;
}
api
.v1TranscriptProcess({ transcriptId })
.then((result) => {
const status =
result && typeof result === "object" && "status" in result
? (result as { status: string }).status
: undefined;
if (status === "already running") {
setError(
new Error("Processing is already running, please wait"),
"Processing is already running, please wait",
);
}
})
.catch((err) => {
setError(err, "There was an error processing the transcript");
});
};

const transcriptToDelete = results?.find(
(i) => i.id === transcriptToDeleteId,
);
const dialogTitle = transcriptToDelete?.title || "Unnamed Transcript";
const dialogDate = transcriptToDelete?.created_at
? formatLocalDate(transcriptToDelete.created_at)
: undefined;
const dialogSource =
transcriptToDelete?.source_kind === "room" && transcriptToDelete?.room_id
? transcriptToDelete.room_name || transcriptToDelete.room_id
: transcriptToDelete?.source_kind;

if (isLoading && results.length === 0) {
if (loading && !response)
return (
<Flex
flexDir="column"
@@ -343,7 +76,82 @@ export default function TranscriptBrowser() {
<Spinner size="xl" />
</Flex>
);
}

if (!loading && !response)
return (
<Flex
flexDir="column"
alignItems="center"
justifyContent="center"
h="100%"
>
<Text>
No transcripts found, but you can
<Link href="/transcripts/new" className="underline">
record a meeting
</Link>
to get started.
</Text>
</Flex>
);

const onCloseDeletion = () => setTranscriptToDeleteId(undefined);

const confirmDeleteTranscript = (transcriptId: string) => {
if (!api || deletionLoading) return;
setDeletionLoading(true);
api
.v1TranscriptDelete({ transcriptId })
.then(() => {
refetch();
setDeletionLoading(false);
onCloseDeletion();
setDeletedItemIds((prev) =>
prev ? [...prev, transcriptId] : [transcriptId],
);
})
.catch((err) => {
setDeletionLoading(false);
setError(err, "There was an error deleting the transcript");
});
};

const handleDeleteTranscript = (transcriptId: string) => (e: any) => {
e?.stopPropagation?.();
setTranscriptToDeleteId(transcriptId);
};

const handleProcessTranscript = (transcriptId) => (e) => {
if (api) {
api
.v1TranscriptProcess({ transcriptId })
.then((result) => {
const status = (result as any).status;
if (status === "already running") {
setError(
new Error("Processing is already running, please wait"),
"Processing is already running, please wait",
);
}
})
.catch((err) => {
setError(err, "There was an error processing the transcript");
});
}
};

const transcriptToDelete = response?.items?.find(
(i) => i.id === transcriptToDeleteId,
);
const dialogTitle = transcriptToDelete?.title || "Unnamed Transcript";
const dialogDate = transcriptToDelete?.created_at
? formatLocalDate(transcriptToDelete.created_at)
: undefined;
const dialogSource = transcriptToDelete
? transcriptToDelete.source_kind === "room"
? transcriptToDelete.room_name || undefined
: transcriptToDelete.source_kind
: undefined;

return (
<Flex
@@ -360,15 +168,15 @@ export default function TranscriptBrowser() {
>
<Heading size="lg">
{userName ? `${userName}'s Transcriptions` : "Your Transcriptions"}{" "}
{(isLoading || deletionLoading) && <Spinner size="sm" />}
{loading || (deletionLoading && <Spinner size="sm" />)}
</Heading>
</Flex>

<Flex flexDir={{ base: "column", md: "row" }}>
<FilterSidebar
rooms={rooms}
selectedSourceKind={urlSourceKind}
selectedRoomId={urlRoomId}
selectedSourceKind={selectedSourceKind}
selectedRoomId={selectedRoomId}
onFilterChange={handleFilterTranscripts}
/>

@@ -380,37 +188,25 @@ export default function TranscriptBrowser() {
gap={4}
px={{ base: 0, md: 4 }}
>
<SearchForm
<SearchBar onSearch={handleSearch} />
<Pagination
page={page}
setPage={setPage}
sourceKind={urlSourceKind}
roomId={urlRoomId}
searchQuery={urlSearchQuery}
setSearchQuery={setUrlSearchQuery}
setSourceKind={setUrlSourceKind}
setRoomId={setUrlRoomId}
rooms={rooms}
total={response?.total || 0}
size={response?.size || 0}
/>

{totalPages > 1 ? (
<Pagination
page={page}
setPage={setPage}
total={totalResults}
size={pageSize}
/>
) : null}

<TranscriptCards
results={results}
query={urlSearchQuery}
isLoading={isLoading}
onDelete={setTranscriptToDeleteId}
<TranscriptTable
transcripts={response?.items || []}
onDelete={handleDeleteTranscript}
onReprocess={handleProcessTranscript}
loading={loading}
/>
<TranscriptCards
transcripts={response?.items || []}
onDelete={handleDeleteTranscript}
onReprocess={handleProcessTranscript}
loading={loading}
/>

{!isLoading && results.length === 0 && (
<EmptyResult searchQuery={urlSearchQuery} />
)}
</Flex>
</Flex>

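A compact sketch of the two-step delete flow the new page wires up (select an id, confirm, call the API, refetch); the api object here is a stub, not the generated client:

type Api = { v1TranscriptDelete: (args: { transcriptId: string }) => Promise<void> };

// Mirrors confirmDeleteTranscript above: delete on the server, then reload the list.
async function deleteFlow(api: Api, transcriptId: string, refetch: () => void) {
  await api.v1TranscriptDelete({ transcriptId });
  refetch();
}

// Example invocation with stubs:
deleteFlow(
  { v1TranscriptDelete: async () => {} },
  "abc-123",
  () => console.log("list refetched"),
);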
@@ -5,7 +5,6 @@ import Image from "next/image";
import About from "../(aboutAndPrivacy)/about";
import Privacy from "../(aboutAndPrivacy)/privacy";
import UserInfo from "../(auth)/userInfo";
import { RECORD_A_MEETING_URL } from "../api/urls";

export default async function AppLayout({
children,
@@ -54,7 +53,7 @@ export default async function AppLayout({
{/* Text link on the right */}
<Link
as={NextLink}
href={RECORD_A_MEETING_URL}
href="/transcripts/new"
className="font-light px-2"
>
Create

258
www/app/(app)/rooms/_components/ICSSettings.tsx
Normal file
@@ -0,0 +1,258 @@
import {
VStack,
HStack,
Field,
Input,
Select,
Checkbox,
Button,
Text,
Alert,
AlertIcon,
AlertTitle,
Badge,
createListCollection,
Spinner,
} from "@chakra-ui/react";
import { useState } from "react";
import { FaSync, FaCheckCircle, FaExclamationCircle } from "react-icons/fa";
import useApi from "../../../lib/useApi";

interface ICSSettingsProps {
roomId?: string;
roomName?: string;
icsUrl?: string;
icsEnabled?: boolean;
icsFetchInterval?: number;
icsLastSync?: string;
icsLastEtag?: string;
onChange: (settings: Partial<ICSSettingsData>) => void;
isOwner?: boolean;
}

export interface ICSSettingsData {
ics_url: string;
ics_enabled: boolean;
ics_fetch_interval: number;
}

const fetchIntervalOptions = [
{ label: "1 minute", value: "1" },
{ label: "5 minutes", value: "5" },
{ label: "10 minutes", value: "10" },
{ label: "30 minutes", value: "30" },
{ label: "1 hour", value: "60" },
];

export default function ICSSettings({
roomId,
roomName,
icsUrl = "",
icsEnabled = false,
icsFetchInterval = 5,
icsLastSync,
icsLastEtag,
onChange,
isOwner = true,
}: ICSSettingsProps) {
const [syncStatus, setSyncStatus] = useState<
"idle" | "syncing" | "success" | "error"
>("idle");
const [syncMessage, setSyncMessage] = useState<string>("");
const [testResult, setTestResult] = useState<string>("");
const api = useApi();

const fetchIntervalCollection = createListCollection({
items: fetchIntervalOptions,
});

const handleTestConnection = async () => {
if (!api || !icsUrl || !roomName) return;

setSyncStatus("syncing");
setTestResult("");

try {
// First update the room with the ICS URL
await api.v1RoomsPartialUpdate({
roomId: roomId || roomName,
requestBody: {
ics_url: icsUrl,
ics_enabled: true,
ics_fetch_interval: icsFetchInterval,
},
});

// Then trigger a sync
const result = await api.v1RoomsTriggerIcsSync({ roomName });

if (result.status === "success") {
setSyncStatus("success");
setTestResult(
`Successfully synced! Found ${result.events_found} events.`,
);
} else {
setSyncStatus("error");
setTestResult(result.error || "Sync failed");
}
} catch (err: any) {
setSyncStatus("error");
setTestResult(err.body?.detail || "Failed to test ICS connection");
}
};

const handleManualSync = async () => {
if (!api || !roomName) return;

setSyncStatus("syncing");
setSyncMessage("");

try {
const result = await api.v1RoomsTriggerIcsSync({ roomName });

if (result.status === "success") {
setSyncStatus("success");
setSyncMessage(
`Sync complete! Found ${result.events_found} events, ` +
`created ${result.events_created}, updated ${result.events_updated}.`,
);
} else {
setSyncStatus("error");
setSyncMessage(result.error || "Sync failed");
}
} catch (err: any) {
setSyncStatus("error");
setSyncMessage(err.body?.detail || "Failed to sync calendar");
}

// Clear status after 5 seconds
setTimeout(() => {
setSyncStatus("idle");
setSyncMessage("");
}, 5000);
};

if (!isOwner) {
return null; // ICS settings only visible to room owner
}

return (
<VStack spacing={4} align="stretch" mt={6}>
<Text fontWeight="semibold" fontSize="lg">
Calendar Integration (ICS)
</Text>

<Field.Root>
<Checkbox.Root
checked={icsEnabled}
onCheckedChange={(e) => onChange({ ics_enabled: e.checked })}
>
<Checkbox.HiddenInput />
<Checkbox.Control>
<Checkbox.Indicator />
</Checkbox.Control>
<Checkbox.Label>Enable ICS calendar sync</Checkbox.Label>
</Checkbox.Root>
</Field.Root>

{icsEnabled && (
<>
<Field.Root>
<Field.Label>ICS Calendar URL</Field.Label>
<Input
placeholder="https://calendar.google.com/calendar/ical/..."
value={icsUrl}
onChange={(e) => onChange({ ics_url: e.target.value })}
/>
<Field.HelperText>
Enter the ICS URL from Google Calendar, Outlook, or other calendar
services
</Field.HelperText>
</Field.Root>

<Field.Root>
<Field.Label>Sync Interval</Field.Label>
<Select.Root
collection={fetchIntervalCollection}
value={[icsFetchInterval.toString()]}
onValueChange={(details) => {
const value = parseInt(details.value[0]);
onChange({ ics_fetch_interval: value });
}}
>
<Select.Trigger>
<Select.ValueText />
</Select.Trigger>
<Select.Content>
{fetchIntervalOptions.map((option) => (
<Select.Item key={option.value} item={option}>
{option.label}
</Select.Item>
))}
</Select.Content>
</Select.Root>
<Field.HelperText>
How often to check for calendar updates
</Field.HelperText>
</Field.Root>

{icsUrl && (
<HStack spacing={3}>
<Button
size="sm"
variant="outline"
onClick={handleTestConnection}
disabled={syncStatus === "syncing"}
leftIcon={
syncStatus === "syncing" ? <Spinner size="sm" /> : undefined
}
>
Test Connection
</Button>

{roomName && icsLastSync && (
<Button
size="sm"
variant="outline"
onClick={handleManualSync}
disabled={syncStatus === "syncing"}
leftIcon={<FaSync />}
>
Sync Now
</Button>
)}
</HStack>
)}

{testResult && (
<Alert status={syncStatus === "success" ? "success" : "error"}>
<AlertIcon />
<Text fontSize="sm">{testResult}</Text>
</Alert>
)}

{syncMessage && (
<Alert status={syncStatus === "success" ? "success" : "error"}>
<AlertIcon />
<Text fontSize="sm">{syncMessage}</Text>
</Alert>
)}

{icsLastSync && (
<HStack spacing={4} fontSize="sm" color="gray.600">
<HStack>
<FaCheckCircle color="green" />
<Text>Last sync: {new Date(icsLastSync).toLocaleString()}</Text>
</HStack>
{icsLastEtag && (
<Badge colorScheme="blue" fontSize="xs">
ETag: {icsLastEtag.slice(0, 8)}...
</Badge>
)}
</HStack>
)}
</>
)}
</VStack>
);
}
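The onChange contract above sends partial updates; a minimal standalone sketch of the merge the parent performs (the same explicit undefined checks as in the rooms page below, just extracted into a function):

interface ICSSettingsData {
  ics_url: string;
  ics_enabled: boolean;
  ics_fetch_interval: number;
}

// Keep each current field unless the update explicitly provides a new value.
function mergeIcsSettings(
  current: ICSSettingsData,
  update: Partial<ICSSettingsData>,
): ICSSettingsData {
  return {
    ics_url: update.ics_url !== undefined ? update.ics_url : current.ics_url,
    ics_enabled:
      update.ics_enabled !== undefined ? update.ics_enabled : current.ics_enabled,
    ics_fetch_interval:
      update.ics_fetch_interval !== undefined
        ? update.ics_fetch_interval
        : current.ics_fetch_interval,
  };
}

const next = mergeIcsSettings(
  { ics_url: "", ics_enabled: false, ics_fetch_interval: 5 },
  { ics_enabled: true },
);
console.log(next); // { ics_url: "", ics_enabled: true, ics_fetch_interval: 5 }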
@@ -19,7 +19,7 @@ import useApi from "../../lib/useApi";
import useRoomList from "./useRoomList";
import { ApiError, Room } from "../../api";
import { RoomList } from "./_components/RoomList";
import { PaginationPage } from "../browse/_components/Pagination";
import ICSSettings from "./_components/ICSSettings";

interface SelectOption {
label: string;
@@ -55,6 +55,9 @@ const roomInitialState = {
recordingType: "cloud",
recordingTrigger: "automatic-2nd-participant",
isShared: false,
icsUrl: "",
icsEnabled: false,
icsFetchInterval: 5,
};

export default function RoomsList() {
@@ -76,9 +79,8 @@ export default function RoomsList() {
const [isEditing, setIsEditing] = useState(false);
const [editRoomId, setEditRoomId] = useState("");
const api = useApi();
// TODO seems to be no setPage calls
const [page, setPage] = useState<number>(1);
const { loading, response, refetch } = useRoomList(PaginationPage(page));
const { loading, response, refetch } = useRoomList(page);
const [streams, setStreams] = useState<Stream[]>([]);
const [topics, setTopics] = useState<Topic[]>([]);
const [nameError, setNameError] = useState("");
@@ -172,6 +174,9 @@ export default function RoomsList() {
recording_type: room.recordingType,
recording_trigger: room.recordingTrigger,
is_shared: room.isShared,
ics_url: room.icsUrl,
ics_enabled: room.icsEnabled,
ics_fetch_interval: room.icsFetchInterval,
};

if (isEditing) {
@@ -217,6 +222,9 @@ export default function RoomsList() {
recordingType: roomData.recording_type,
recordingTrigger: roomData.recording_trigger,
isShared: roomData.is_shared,
icsUrl: roomData.ics_url || "",
icsEnabled: roomData.ics_enabled || false,
icsFetchInterval: roomData.ics_fetch_interval || 5,
});
setEditRoomId(roomId);
setIsEditing(true);
@@ -555,6 +563,32 @@ export default function RoomsList() {
<Checkbox.Label>Shared room</Checkbox.Label>
</Checkbox.Root>
</Field.Root>

<ICSSettings
roomId={editRoomId}
roomName={room.name}
icsUrl={room.icsUrl}
icsEnabled={room.icsEnabled}
icsFetchInterval={room.icsFetchInterval}
onChange={(settings) => {
setRoom({
...room,
icsUrl:
settings.ics_url !== undefined
? settings.ics_url
: room.icsUrl,
icsEnabled:
settings.ics_enabled !== undefined
? settings.ics_enabled
: room.icsEnabled,
icsFetchInterval:
settings.ics_fetch_interval !== undefined
? settings.ics_fetch_interval
: room.icsFetchInterval,
});
}}
isOwner={true}
/>
</Dialog.Body>
<Dialog.Footer>
<Button variant="ghost" onClick={onClose}>

@@ -2,7 +2,6 @@ import { useEffect, useState } from "react";
import { useError } from "../../(errors)/errorContext";
import useApi from "../../lib/useApi";
import { Page_Room_ } from "../../api";
import { PaginationPage } from "../browse/_components/Pagination";

type RoomList = {
response: Page_Room_ | null;
@@ -12,7 +11,7 @@ type RoomList = {
};

//always protected
const useRoomList = (page: PaginationPage): RoomList => {
const useRoomList = (page: number): RoomList => {
const [response, setResponse] = useState<Page_Room_ | null>(null);
const [loading, setLoading] = useState<boolean>(true);
const [error, setErrorState] = useState<Error | null>(null);

32
www/app/(app)/transcripts/mockTopics.json
Normal file
@@ -0,0 +1,32 @@
[
{
"id": "27c07e49-d7a3-4b86-905c-f1a047366f91",
"title": "Issue one",
"summary": "The team discusses the first issue in the list",
"timestamp": 0.0,
"transcript": "",
"duration": 33,
"segments": [
{
"text": "Let's start with issue one, Alice you've been working on that, can you give an update ?",
"start": 0.0,
"speaker": 0
},
{
"text": "Yes, I've run into an issue with the task system but Bob helped me out and I have a POC ready, should I present it now ?",
"start": 0.38,
"speaker": 1
},
{
"text": "Yeah, I had to modify the task system because it didn't account for incoming blobs",
"start": 4.5,
"speaker": 2
},
{
"text": "Cool, yeah lets see it",
"start": 5.96,
"speaker": 0
}
]
}
]
@@ -11,7 +11,6 @@ import useWebRTC from "./useWebRTC";
import useAudioDevice from "./useAudioDevice";
import { Box, Flex, IconButton, Menu, RadioGroup } from "@chakra-ui/react";
import { LuScreenShare, LuMic, LuPlay, LuCircleStop } from "react-icons/lu";
import { RECORD_A_MEETING_URL } from "../../api/urls";

type RecorderProps = {
transcriptId: string;
@@ -47,7 +46,7 @@ export default function Recorder(props: RecorderProps) {
location.href = "";
break;
case ",":
location.href = RECORD_A_MEETING_URL;
location.href = "/transcripts/new";
break;
case "!":
if (record.isRecording()) return;

@@ -1,123 +0,0 @@
// this hook is not great, we want to substitute it with a proper state management solution that is also not re-invention

import { useEffect, useRef, useState } from "react";
import { SearchResult, SourceKind } from "../../api";
import useApi from "../../lib/useApi";
import {
PaginationPage,
paginationPageTo0Based,
} from "../browse/_components/Pagination";

interface SearchFilters {
roomIds: readonly string[] | null;
sourceKind: SourceKind | null;
}

const EMPTY_SEARCH_FILTERS: SearchFilters = {
roomIds: null,
sourceKind: null,
};

type UseSearchTranscriptsOptions = {
pageSize: number;
page: PaginationPage;
};

interface UseSearchTranscriptsReturn {
results: SearchResult[];
totalCount: number;
isLoading: boolean;
error: unknown;
reload: () => void;
}

function hashEffectFilters(filters: SearchFilters): string {
return JSON.stringify(filters);
}

export function useSearchTranscripts(
query: string = "",
filters: SearchFilters = EMPTY_SEARCH_FILTERS,
options: UseSearchTranscriptsOptions = {
pageSize: 20,
page: PaginationPage(1),
},
): UseSearchTranscriptsReturn {
const { pageSize, page } = options;

const [reloadCount, setReloadCount] = useState(0);

const api = useApi();
const abortControllerRef = useRef<AbortController>();

const [data, setData] = useState<{ results: SearchResult[]; total: number }>({
results: [],
total: 0,
});
const [error, setError] = useState<any>();
const [isLoading, setIsLoading] = useState(false);

const filterHash = hashEffectFilters(filters);

useEffect(() => {
if (!api) {
setData({ results: [], total: 0 });
setError(undefined);
setIsLoading(false);
return;
}

if (abortControllerRef.current) {
abortControllerRef.current.abort();
}

const abortController = new AbortController();
abortControllerRef.current = abortController;

const performSearch = async () => {
setIsLoading(true);

try {
const response = await api.v1TranscriptsSearch({
q: query || "",
limit: pageSize,
offset: paginationPageTo0Based(page) * pageSize,
roomId: filters.roomIds?.[0],
sourceKind: filters.sourceKind || undefined,
});

if (abortController.signal.aborted) return;
setData(response);
setError(undefined);
} catch (err: unknown) {
if ((err as Error).name === "AbortError") {
return;
}
if (abortController.signal.aborted) {
console.error("Aborted search but error", err);
return;
}

setError(err);
} finally {
if (!abortController.signal.aborted) {
setIsLoading(false);
}
}
};

performSearch().then(() => {});

return () => {
abortController.abort();
};
}, [api, query, page, filterHash, pageSize, reloadCount]);

return {
results: data.results,
totalCount: data.total,
isLoading,
error,
reload: () => setReloadCount(reloadCount + 1),
};
}
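The deleted hook's core defensive move, aborting any in-flight request before starting a new one, in isolation; fetchResults is a stand-in for the generated client call, not part of the diff:

let current: AbortController | undefined;

async function search(
  query: string,
  fetchResults: (q: string, signal: AbortSignal) => Promise<string[]>,
) {
  current?.abort(); // supersede any in-flight request
  const controller = new AbortController();
  current = controller;
  try {
    const results = await fetchResults(query, controller.signal);
    if (controller.signal.aborted) return; // a newer search replaced this one
    console.log(results);
  } catch (err) {
    if ((err as Error).name === "AbortError") return; // expected on supersede
    throw err;
  }
}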
59
www/app/(app)/transcripts/useTranscriptList.ts
Normal file
@@ -0,0 +1,59 @@
import { useEffect, useState } from "react";
import { useError } from "../../(errors)/errorContext";
import useApi from "../../lib/useApi";
import { Page_GetTranscriptMinimal_, SourceKind } from "../../api";

type TranscriptList = {
response: Page_GetTranscriptMinimal_ | null;
loading: boolean;
error: Error | null;
refetch: () => void;
};

const useTranscriptList = (
page: number,
sourceKind: SourceKind | null,
roomId: string | null,
searchTerm: string | null,
): TranscriptList => {
const [response, setResponse] = useState<Page_GetTranscriptMinimal_ | null>(
null,
);
const [loading, setLoading] = useState<boolean>(true);
const [error, setErrorState] = useState<Error | null>(null);
const { setError } = useError();
const api = useApi();
const [refetchCount, setRefetchCount] = useState(0);

const refetch = () => {
setLoading(true);
setRefetchCount(refetchCount + 1);
};

useEffect(() => {
if (!api) return;
setLoading(true);
api
.v1TranscriptsList({
page,
sourceKind,
roomId,
searchTerm,
size: 10,
})
.then((response) => {
setResponse(response);
setLoading(false);
})
.catch((err) => {
setResponse(null);
setLoading(false);
setError(err);
setErrorState(err);
});
}, [api, page, refetchCount, roomId, searchTerm, sourceKind]);

return { response, loading, error, refetch };
};

export default useTranscriptList;
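The refetch() above works by bumping a counter that sits in the effect's dependency list, which forces the fetch effect to re-run even though no query input changed; the trick in isolation, framework-free and with an illustrative name:

// makeRefetcher is a hypothetical helper; in the hook this is
// setRefetchCount(refetchCount + 1) feeding the useEffect deps.
function makeRefetcher(run: () => void) {
  let count = 0;
  return () => {
    count += 1; // the changed value is what invalidates the effect
    run();
  };
}

const refetch = makeRefetcher(() => console.log("list reloaded"));
refetch();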
203
www/app/[roomName]/MeetingInfo.tsx
Normal file
@@ -0,0 +1,203 @@
import {
Box,
VStack,
HStack,
Text,
Badge,
Icon,
Divider,
} from "@chakra-ui/react";
import { FaCalendarAlt, FaUsers, FaClock, FaInfoCircle } from "react-icons/fa";
import { Meeting } from "../api";

interface MeetingInfoProps {
meeting: Meeting;
isOwner: boolean;
}

export default function MeetingInfo({ meeting, isOwner }: MeetingInfoProps) {
const formatDuration = (start: string | Date, end: string | Date) => {
const startDate = new Date(start);
const endDate = new Date(end);
const now = new Date();

// If meeting hasn't started yet
if (startDate > now) {
return `Scheduled for ${startDate.toLocaleTimeString()}`;
}

// Calculate duration
const durationMs = now.getTime() - startDate.getTime();
const hours = Math.floor(durationMs / (1000 * 60 * 60));
const minutes = Math.floor((durationMs % (1000 * 60 * 60)) / (1000 * 60));

if (hours > 0) {
return `${hours}h ${minutes}m`;
}
return `${minutes} minutes`;
};

const isCalendarMeeting = !!meeting.calendar_event_id;
const metadata = meeting.calendar_metadata;

return (
<Box
position="absolute"
top="56px"
right="8px"
bg="white"
borderRadius="md"
boxShadow="lg"
p={4}
maxW="300px"
zIndex={999}
>
<VStack align="stretch" spacing={3}>
{/* Meeting Title */}
<HStack>
<Icon
as={isCalendarMeeting ? FaCalendarAlt : FaInfoCircle}
color="blue.500"
/>
<Text fontWeight="semibold" fontSize="md">
{metadata?.title ||
(isCalendarMeeting ? "Calendar Meeting" : "Unscheduled Meeting")}
</Text>
</HStack>

{/* Meeting Status */}
<HStack spacing={2}>
{meeting.is_active && (
<Badge colorScheme="green" fontSize="xs">
Active
</Badge>
)}
{isCalendarMeeting && (
<Badge colorScheme="blue" fontSize="xs">
Calendar
</Badge>
)}
{meeting.is_locked && (
<Badge colorScheme="orange" fontSize="xs">
Locked
</Badge>
)}
</HStack>

<Divider />

{/* Meeting Details */}
<VStack align="stretch" spacing={2} fontSize="sm">
{/* Participants */}
<HStack>
<Icon as={FaUsers} color="gray.500" />
<Text>
{meeting.num_clients}{" "}
{meeting.num_clients === 1 ? "participant" : "participants"}
</Text>
</HStack>

{/* Duration */}
<HStack>
<Icon as={FaClock} color="gray.500" />
<Text>
Duration: {formatDuration(meeting.start_date, meeting.end_date)}
</Text>
</HStack>

{/* Calendar Description (Owner only) */}
{isOwner && metadata?.description && (
<>
<Divider />
<Box>
<Text
fontWeight="semibold"
fontSize="xs"
color="gray.600"
mb={1}
>
Description
</Text>
<Text fontSize="xs" color="gray.700">
{metadata.description}
</Text>
</Box>
</>
)}

{/* Attendees (Owner only) */}
{isOwner && metadata?.attendees && metadata.attendees.length > 0 && (
<>
<Divider />
<Box>
<Text
fontWeight="semibold"
fontSize="xs"
color="gray.600"
mb={1}
>
Invited Attendees ({metadata.attendees.length})
</Text>
<VStack align="stretch" spacing={1}>
{metadata.attendees
.slice(0, 5)
.map((attendee: any, idx: number) => (
<HStack key={idx} fontSize="xs">
<Badge
colorScheme={
attendee.status === "ACCEPTED"
? "green"
: attendee.status === "DECLINED"
? "red"
: attendee.status === "TENTATIVE"
? "yellow"
: "gray"
}
fontSize="xs"
size="sm"
>
{attendee.status?.charAt(0) || "?"}
</Badge>
<Text color="gray.700" isTruncated>
{attendee.name || attendee.email}
</Text>
</HStack>
))}
{metadata.attendees.length > 5 && (
<Text fontSize="xs" color="gray.500" fontStyle="italic">
+{metadata.attendees.length - 5} more
</Text>
)}
</VStack>
</Box>
</>
)}

{/* Recording Info */}
{meeting.recording_type !== "none" && (
<>
<Divider />
<HStack fontSize="xs">
<Badge colorScheme="red" fontSize="xs">
Recording
</Badge>
<Text color="gray.600">
{meeting.recording_type === "cloud" ? "Cloud" : "Local"}
{meeting.recording_trigger !== "none" &&
` (${meeting.recording_trigger})`}
</Text>
</HStack>
</>
)}
</VStack>

{/* Meeting Times */}
<Divider />
<VStack align="stretch" spacing={1} fontSize="xs" color="gray.600">
<Text>Start: {new Date(meeting.start_date).toLocaleString()}</Text>
<Text>End: {new Date(meeting.end_date).toLocaleString()}</Text>
</VStack>
</VStack>
</Box>
);
}
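The elapsed-time arithmetic in formatDuration above, in a standalone testable form (same math, minus the JSX and Date parsing):

function describeDuration(startMs: number, nowMs: number): string {
  const durationMs = nowMs - startMs;
  const hours = Math.floor(durationMs / (1000 * 60 * 60));
  const minutes = Math.floor((durationMs % (1000 * 60 * 60)) / (1000 * 60));
  return hours > 0 ? `${hours}h ${minutes}m` : `${minutes} minutes`;
}

// 95 minutes elapsed:
console.log(describeDuration(0, 95 * 60 * 1000)); // "1h 35m"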
335
www/app/[roomName]/MeetingSelection.tsx
Normal file
@@ -0,0 +1,335 @@
"use client";

import {
Box,
VStack,
HStack,
Text,
Button,
Spinner,
Card,
CardBody,
CardHeader,
Badge,
Divider,
Icon,
Alert,
AlertIcon,
AlertTitle,
AlertDescription,
} from "@chakra-ui/react";
import { useEffect, useState } from "react";
import { FaUsers, FaClock, FaCalendarAlt, FaPlus } from "react-icons/fa";
import { Meeting, CalendarEventResponse } from "../api";
import useApi from "../lib/useApi";
import { useRouter } from "next/navigation";

interface MeetingSelectionProps {
roomName: string;
isOwner: boolean;
onMeetingSelect: (meeting: Meeting) => void;
onCreateUnscheduled: () => void;
}

const formatDateTime = (date: string | Date) => {
const d = new Date(date);
return d.toLocaleString("en-US", {
month: "short",
day: "numeric",
hour: "2-digit",
minute: "2-digit",
});
};

const formatCountdown = (startTime: string | Date) => {
const now = new Date();
const start = new Date(startTime);
const diff = start.getTime() - now.getTime();

if (diff <= 0) return "Starting now";

const minutes = Math.floor(diff / 60000);
const hours = Math.floor(minutes / 60);

if (hours > 0) {
return `Starts in ${hours}h ${minutes % 60}m`;
}
return `Starts in ${minutes} minutes`;
};

export default function MeetingSelection({
roomName,
isOwner,
onMeetingSelect,
onCreateUnscheduled,
}: MeetingSelectionProps) {
const [activeMeetings, setActiveMeetings] = useState<Meeting[]>([]);
const [upcomingEvents, setUpcomingEvents] = useState<CalendarEventResponse[]>(
[],
);
const [loading, setLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
const api = useApi();
const router = useRouter();

useEffect(() => {
if (!api) return;

const fetchMeetings = async () => {
try {
setLoading(true);

// Fetch active meetings
const active = await api.v1RoomsListActiveMeetings({ roomName });
setActiveMeetings(active);

// Fetch upcoming calendar events (30 min ahead)
const upcoming = await api.v1RoomsListUpcomingMeetings({
roomName,
minutesAhead: 30,
});
setUpcomingEvents(upcoming);

setError(null);
} catch (err) {
console.error("Failed to fetch meetings:", err);
setError("Failed to load meetings. Please try again.");
} finally {
setLoading(false);
}
};

fetchMeetings();

// Refresh every 30 seconds
const interval = setInterval(fetchMeetings, 30000);
return () => clearInterval(interval);
}, [api, roomName]);

const handleJoinMeeting = async (meetingId: string) => {
if (!api) return;

try {
const meeting = await api.v1RoomsJoinMeeting({
roomName,
meetingId,
});
onMeetingSelect(meeting);
} catch (err) {
console.error("Failed to join meeting:", err);
setError("Failed to join meeting. Please try again.");
}
};

const handleJoinUpcoming = (event: CalendarEventResponse) => {
// Navigate to waiting page with event info
router.push(`/room/${roomName}/wait?eventId=${event.id}`);
};

if (loading) {
return (
<Box p={8} textAlign="center">
<Spinner size="lg" color="blue.500" />
<Text mt={4}>Loading meetings...</Text>
</Box>
);
}

if (error) {
return (
<Alert status="error" borderRadius="md">
<AlertIcon />
<AlertTitle>Error</AlertTitle>
<AlertDescription>{error}</AlertDescription>
</Alert>
);
}

return (
<VStack spacing={6} align="stretch" p={6}>
<Box>
<Text fontSize="2xl" fontWeight="bold" mb={4}>
Select a Meeting
</Text>

{/* Active Meetings */}
{activeMeetings.length > 0 && (
<>
<Text fontSize="lg" fontWeight="semibold" mb={3}>
Active Meetings
</Text>
<VStack spacing={3} mb={6}>
{activeMeetings.map((meeting) => (
<Card key={meeting.id} width="100%" variant="outline">
<CardBody>
<HStack justify="space-between" align="start">
<VStack align="start" spacing={2} flex={1}>
<HStack>
<Icon as={FaCalendarAlt} color="blue.500" />
<Text fontWeight="semibold">
{meeting.calendar_metadata?.title || "Meeting"}
</Text>
</HStack>

{isOwner && meeting.calendar_metadata?.description && (
<Text fontSize="sm" color="gray.600">
{meeting.calendar_metadata.description}
</Text>
)}

<HStack spacing={4} fontSize="sm" color="gray.500">
<HStack>
<Icon as={FaUsers} />
<Text>{meeting.num_clients} participants</Text>
</HStack>
<HStack>
<Icon as={FaClock} />
<Text>
Started {formatDateTime(meeting.start_date)}
</Text>
</HStack>
</HStack>

{isOwner && meeting.calendar_metadata?.attendees && (
<HStack spacing={2} flexWrap="wrap">
{meeting.calendar_metadata.attendees
.slice(0, 3)
.map((attendee: any, idx: number) => (
<Badge
key={idx}
colorScheme="green"
fontSize="xs"
>
{attendee.name || attendee.email}
</Badge>
))}
{meeting.calendar_metadata.attendees.length > 3 && (
<Badge colorScheme="gray" fontSize="xs">
+
{meeting.calendar_metadata.attendees.length - 3}{" "}
more
</Badge>
)}
</HStack>
)}
</VStack>

<Button
colorScheme="blue"
size="md"
onClick={() => handleJoinMeeting(meeting.id)}
>
Join Now
</Button>
</HStack>
</CardBody>
</Card>
))}
</VStack>
</>
)}

{/* Upcoming Meetings */}
{upcomingEvents.length > 0 && (
<>
<Text fontSize="lg" fontWeight="semibold" mb={3}>
Upcoming Meetings
</Text>
<VStack spacing={3} mb={6}>
{upcomingEvents.map((event) => (
<Card
key={event.id}
width="100%"
variant="outline"
bg="gray.50"
>
<CardBody>
<HStack justify="space-between" align="start">
<VStack align="start" spacing={2} flex={1}>
<HStack>
<Icon as={FaCalendarAlt} color="orange.500" />
<Text fontWeight="semibold">
{event.title || "Scheduled Meeting"}
</Text>
<Badge colorScheme="orange" fontSize="xs">
{formatCountdown(event.start_time)}
</Badge>
</HStack>

{isOwner && event.description && (
<Text fontSize="sm" color="gray.600">
{event.description}
</Text>
)}

<HStack spacing={4} fontSize="sm" color="gray.500">
<Text>
{formatDateTime(event.start_time)} -{" "}
{formatDateTime(event.end_time)}
</Text>
</HStack>

{isOwner && event.attendees && (
<HStack spacing={2} flexWrap="wrap">
{event.attendees
.slice(0, 3)
.map((attendee: any, idx: number) => (
<Badge
key={idx}
colorScheme="purple"
fontSize="xs"
>
{attendee.name || attendee.email}
</Badge>
))}
{event.attendees.length > 3 && (
<Badge colorScheme="gray" fontSize="xs">
+{event.attendees.length - 3} more
</Badge>
)}
</HStack>
)}
</VStack>

<Button
variant="outline"
colorScheme="orange"
size="md"
onClick={() => handleJoinUpcoming(event)}
>
Join Early
</Button>
</HStack>
</CardBody>
</Card>
))}
</VStack>
</>
)}

<Divider my={6} />

{/* Create Unscheduled Meeting */}
<Card width="100%" variant="filled" bg="gray.100">
<CardBody>
<HStack justify="space-between" align="center">
<VStack align="start" spacing={1}>
<Text fontWeight="semibold">Start an Unscheduled Meeting</Text>
<Text fontSize="sm" color="gray.600">
Create a new meeting room that's not on the calendar
</Text>
</VStack>
<Button
leftIcon={<FaPlus />}
colorScheme="green"
onClick={onCreateUnscheduled}
>
Create Meeting
</Button>
</HStack>
</CardBody>
|
||||
</Card>
|
||||
</Box>
|
||||
</VStack>
|
||||
);
|
||||
}
|
||||
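The effect above fetches immediately and then refreshes both lists every 30 seconds via setInterval. A minimal sketch of that polling pattern extracted into a reusable hook; the usePolling helper below is illustrative and not part of this diff:

import { useEffect, useRef } from "react";

// Illustrative polling helper: runs `fn` once right away, then every `ms`
// milliseconds, and clears the timer on unmount or when inputs change.
export function usePolling(fn: () => void, ms: number, deps: unknown[]) {
  const saved = useRef(fn);
  saved.current = fn; // always call the latest closure

  useEffect(() => {
    saved.current(); // initial fetch, matching fetchMeetings() above
    const id = setInterval(() => saved.current(), ms);
    return () => clearInterval(id);
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [ms, ...deps]);
}

Usage would be usePolling(fetchMeetings, 30000, [api, roomName]), which mirrors the interval and cleanup behavior in the component above.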
@@ -24,8 +24,9 @@ import { notFound } from "next/navigation";
import useSessionStatus from "../lib/useSessionStatus";
import { useRecordingConsent } from "../recordingConsentContext";
import useApi from "../lib/useApi";
import { Meeting } from "../api";
import { FaBars } from "react-icons/fa6";
import { Meeting, Room } from "../api";
import { FaBars, FaInfoCircle } from "react-icons/fa6";
import MeetingInfo from "./MeetingInfo";

export type RoomDetails = {
  params: {
@@ -254,7 +255,10 @@ export default function Room(details: RoomDetails) {
  const roomName = details.params.roomName;
  const meeting = useRoomMeeting(roomName);
  const router = useRouter();
  const { isLoading, isAuthenticated } = useSessionStatus();
  const { isLoading, isAuthenticated, data: session } = useSessionStatus();
  const [showMeetingInfo, setShowMeetingInfo] = useState(false);
  const [room, setRoom] = useState<Room | null>(null);
  const api = useApi();

  const roomUrl = meeting?.response?.host_room_url
    ? meeting?.response?.host_room_url
@@ -268,6 +272,15 @@ export default function Room(details: RoomDetails) {
    router.push("/browse");
  }, [router]);

  // Fetch room details
  useEffect(() => {
    if (!api || !roomName) return;

    api.v1RoomsRetrieve({ roomName }).then(setRoom).catch(console.error);
  }, [api, roomName]);

  const isOwner = session?.user?.id === room?.user_id;

  useEffect(() => {
    if (
      !isLoading &&
@@ -319,6 +332,25 @@ export default function Room(details: RoomDetails) {
          wherebyRef={wherebyRef}
        />
      )}
      {meeting?.response && (
        <>
          <Button
            position="absolute"
            top="56px"
            right={showMeetingInfo ? "320px" : "8px"}
            zIndex={1000}
            colorPalette="blue"
            size="sm"
            onClick={() => setShowMeetingInfo(!showMeetingInfo)}
            leftIcon={<Icon as={FaInfoCircle} />}
          >
            Meeting Info
          </Button>
          {showMeetingInfo && (
            <MeetingInfo meeting={meeting.response} isOwner={isOwner} />
          )}
        </>
      )}
    </>
  )}
</>

@@ -40,6 +40,20 @@ const useRoomMeeting = (
  useEffect(() => {
    if (!roomName || !api) return;

    // Check if meeting was pre-selected from meeting selection page
    const storedMeeting = sessionStorage.getItem(`meeting_${roomName}`);
    if (storedMeeting) {
      try {
        const meeting = JSON.parse(storedMeeting);
        sessionStorage.removeItem(`meeting_${roomName}`); // Clean up
        setResponse(meeting);
        setLoading(false);
        return;
      } catch (e) {
        console.error("Failed to parse stored meeting:", e);
      }
    }

    if (!response) {
      setLoading(true);
    }
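The hook above consumes a one-shot hand-off: the new /room/[roomName] page later in this diff writes the selected Meeting to sessionStorage under meeting_${roomName} before routing, and the hook reads and deletes that key on mount. A minimal sketch of the two sides of that contract, with Meeting narrowed to an id for brevity:

// Illustrative sketch of the sessionStorage hand-off used above.
type Meeting = { id: string }; // narrowed for the sketch

// Write side (the selection page does this before router.push):
function handOff(roomName: string, meeting: Meeting) {
  sessionStorage.setItem(`meeting_${roomName}`, JSON.stringify(meeting));
}

// Read side (the hook does this once on mount, then clears the key):
function takeHandOff(roomName: string): Meeting | null {
  const raw = sessionStorage.getItem(`meeting_${roomName}`);
  if (!raw) return null;
  sessionStorage.removeItem(`meeting_${roomName}`); // one-shot: consume it
  try {
    return JSON.parse(raw) as Meeting;
  } catch {
    return null; // malformed payload: fall through to a normal fetch
  }
}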
@@ -30,6 +30,108 @@ export const $Body_transcript_record_upload_v1_transcripts__transcript_id__recor
    "Body_transcript_record_upload_v1_transcripts__transcript_id__record_upload_post",
} as const;

export const $CalendarEventResponse = {
  properties: {
    id: {
      type: "string",
      title: "Id",
    },
    room_id: {
      type: "string",
      title: "Room Id",
    },
    ics_uid: {
      type: "string",
      title: "Ics Uid",
    },
    title: {
      anyOf: [
        {
          type: "string",
        },
        {
          type: "null",
        },
      ],
      title: "Title",
    },
    description: {
      anyOf: [
        {
          type: "string",
        },
        {
          type: "null",
        },
      ],
      title: "Description",
    },
    start_time: {
      type: "string",
      format: "date-time",
      title: "Start Time",
    },
    end_time: {
      type: "string",
      format: "date-time",
      title: "End Time",
    },
    attendees: {
      anyOf: [
        {
          items: {
            additionalProperties: true,
            type: "object",
          },
          type: "array",
        },
        {
          type: "null",
        },
      ],
      title: "Attendees",
    },
    location: {
      anyOf: [
        {
          type: "string",
        },
        {
          type: "null",
        },
      ],
      title: "Location",
    },
    last_synced: {
      type: "string",
      format: "date-time",
      title: "Last Synced",
    },
    created_at: {
      type: "string",
      format: "date-time",
      title: "Created At",
    },
    updated_at: {
      type: "string",
      format: "date-time",
      title: "Updated At",
    },
  },
  type: "object",
  required: [
    "id",
    "room_id",
    "ics_uid",
    "start_time",
    "end_time",
    "last_synced",
    "created_at",
    "updated_at",
  ],
  title: "CalendarEventResponse",
} as const;

export const $CreateParticipant = {
  properties: {
    speaker: {
@@ -91,6 +193,27 @@ export const $CreateRoom = {
      type: "boolean",
      title: "Is Shared",
    },
    ics_url: {
      anyOf: [
        {
          type: "string",
        },
        {
          type: "null",
        },
      ],
      title: "Ics Url",
    },
    ics_fetch_interval: {
      type: "integer",
      title: "Ics Fetch Interval",
      default: 300,
    },
    ics_enabled: {
      type: "boolean",
      title: "Ics Enabled",
      default: false,
    },
  },
  type: "object",
  required: [
@@ -687,6 +810,112 @@ export const $HTTPValidationError = {
  title: "HTTPValidationError",
} as const;

export const $ICSStatus = {
  properties: {
    status: {
      type: "string",
      title: "Status",
    },
    last_sync: {
      anyOf: [
        {
          type: "string",
          format: "date-time",
        },
        {
          type: "null",
        },
      ],
      title: "Last Sync",
    },
    next_sync: {
      anyOf: [
        {
          type: "string",
          format: "date-time",
        },
        {
          type: "null",
        },
      ],
      title: "Next Sync",
    },
    last_etag: {
      anyOf: [
        {
          type: "string",
        },
        {
          type: "null",
        },
      ],
      title: "Last Etag",
    },
    events_count: {
      type: "integer",
      title: "Events Count",
      default: 0,
    },
  },
  type: "object",
  required: ["status"],
  title: "ICSStatus",
} as const;

export const $ICSSyncResult = {
  properties: {
    status: {
      type: "string",
      title: "Status",
    },
    hash: {
      anyOf: [
        {
          type: "string",
        },
        {
          type: "null",
        },
      ],
      title: "Hash",
    },
    events_found: {
      type: "integer",
      title: "Events Found",
      default: 0,
    },
    events_created: {
      type: "integer",
      title: "Events Created",
      default: 0,
    },
    events_updated: {
      type: "integer",
      title: "Events Updated",
      default: 0,
    },
    events_deleted: {
      type: "integer",
      title: "Events Deleted",
      default: 0,
    },
    error: {
      anyOf: [
        {
          type: "string",
        },
        {
          type: "null",
        },
      ],
      title: "Error",
    },
  },
  type: "object",
  required: ["status"],
  title: "ICSSyncResult",
} as const;

export const $Meeting = {
  properties: {
    id: {
@@ -950,6 +1179,50 @@ export const $Room = {
      type: "boolean",
      title: "Is Shared",
    },
    ics_url: {
      anyOf: [
        {
          type: "string",
        },
        {
          type: "null",
        },
      ],
      title: "Ics Url",
    },
    ics_fetch_interval: {
      type: "integer",
      title: "Ics Fetch Interval",
      default: 300,
    },
    ics_enabled: {
      type: "boolean",
      title: "Ics Enabled",
      default: false,
    },
    ics_last_sync: {
      anyOf: [
        {
          type: "string",
          format: "date-time",
        },
        {
          type: "null",
        },
      ],
      title: "Ics Last Sync",
    },
    ics_last_etag: {
      anyOf: [
        {
          type: "string",
        },
        {
          type: "null",
        },
      ],
      title: "Ics Last Etag",
    },
  },
  type: "object",
  required: [
@@ -1002,7 +1275,7 @@ export const $SearchResponse = {
    },
    query: {
      type: "string",
      minLength: 0,
      minLength: 1,
      title: "Query",
      description: "Search query text",
    },
@@ -1065,20 +1338,6 @@ export const $SearchResult = {
      ],
      title: "Room Id",
    },
    room_name: {
      anyOf: [
        {
          type: "string",
        },
        {
          type: "null",
        },
      ],
      title: "Room Name",
    },
    source_kind: {
      $ref: "#/components/schemas/SourceKind",
    },
    created_at: {
      type: "string",
      title: "Created At",
@@ -1115,18 +1374,10 @@ export const $SearchResult = {
      title: "Search Snippets",
      description: "Text snippets around search matches",
    },
    total_match_count: {
      type: "integer",
      minimum: 0,
      title: "Total Match Count",
      description: "Total number of matches found in the transcript",
      default: 0,
    },
  },
  type: "object",
  required: [
    "id",
    "source_kind",
    "created_at",
    "status",
    "rank",
@@ -1316,54 +1567,139 @@ export const $UpdateParticipant = {
export const $UpdateRoom = {
  properties: {
    name: {
      type: "string",
      anyOf: [
        {
          type: "string",
        },
        {
          type: "null",
        },
      ],
      title: "Name",
    },
    zulip_auto_post: {
      type: "boolean",
      anyOf: [
        {
          type: "boolean",
        },
        {
          type: "null",
        },
      ],
      title: "Zulip Auto Post",
    },
    zulip_stream: {
      type: "string",
      anyOf: [
        {
          type: "string",
        },
        {
          type: "null",
        },
      ],
      title: "Zulip Stream",
    },
    zulip_topic: {
      type: "string",
      anyOf: [
        {
          type: "string",
        },
        {
          type: "null",
        },
      ],
      title: "Zulip Topic",
    },
    is_locked: {
      type: "boolean",
      anyOf: [
        {
          type: "boolean",
        },
        {
          type: "null",
        },
      ],
      title: "Is Locked",
    },
    room_mode: {
      type: "string",
      anyOf: [
        {
          type: "string",
        },
        {
          type: "null",
        },
      ],
      title: "Room Mode",
    },
    recording_type: {
      type: "string",
      anyOf: [
        {
          type: "string",
        },
        {
          type: "null",
        },
      ],
      title: "Recording Type",
    },
    recording_trigger: {
      type: "string",
      anyOf: [
        {
          type: "string",
        },
        {
          type: "null",
        },
      ],
      title: "Recording Trigger",
    },
    is_shared: {
      type: "boolean",
      anyOf: [
        {
          type: "boolean",
        },
        {
          type: "null",
        },
      ],
      title: "Is Shared",
    },
    ics_url: {
      anyOf: [
        {
          type: "string",
        },
        {
          type: "null",
        },
      ],
      title: "Ics Url",
    },
    ics_fetch_interval: {
      anyOf: [
        {
          type: "integer",
        },
        {
          type: "null",
        },
      ],
      title: "Ics Fetch Interval",
    },
    ics_enabled: {
      anyOf: [
        {
          type: "boolean",
        },
        {
          type: "null",
        },
      ],
      title: "Ics Enabled",
    },
  },
  type: "object",
  required: [
    "name",
    "zulip_auto_post",
    "zulip_stream",
    "zulip_topic",
    "is_locked",
    "room_mode",
    "recording_type",
    "recording_trigger",
    "is_shared",
  ],
  title: "UpdateRoom",
} as const;

@@ -16,6 +16,18 @@ import type {
  V1RoomsDeleteResponse,
  V1RoomsCreateMeetingData,
  V1RoomsCreateMeetingResponse,
  V1RoomsSyncIcsData,
  V1RoomsSyncIcsResponse,
  V1RoomsIcsStatusData,
  V1RoomsIcsStatusResponse,
  V1RoomsListMeetingsData,
  V1RoomsListMeetingsResponse,
  V1RoomsListUpcomingMeetingsData,
  V1RoomsListUpcomingMeetingsResponse,
  V1RoomsListActiveMeetingsData,
  V1RoomsListActiveMeetingsResponse,
  V1RoomsJoinMeetingData,
  V1RoomsJoinMeetingResponse,
  V1TranscriptsListData,
  V1TranscriptsListResponse,
  V1TranscriptsCreateData,
@@ -227,6 +239,146 @@ export class DefaultService {
    });
  }

  /**
   * Rooms Sync Ics
   * @param data The data for the request.
   * @param data.roomName
   * @returns ICSSyncResult Successful Response
   * @throws ApiError
   */
  public v1RoomsSyncIcs(
    data: V1RoomsSyncIcsData,
  ): CancelablePromise<V1RoomsSyncIcsResponse> {
    return this.httpRequest.request({
      method: "POST",
      url: "/v1/rooms/{room_name}/ics/sync",
      path: {
        room_name: data.roomName,
      },
      errors: {
        422: "Validation Error",
      },
    });
  }

  /**
   * Rooms Ics Status
   * @param data The data for the request.
   * @param data.roomName
   * @returns ICSStatus Successful Response
   * @throws ApiError
   */
  public v1RoomsIcsStatus(
    data: V1RoomsIcsStatusData,
  ): CancelablePromise<V1RoomsIcsStatusResponse> {
    return this.httpRequest.request({
      method: "GET",
      url: "/v1/rooms/{room_name}/ics/status",
      path: {
        room_name: data.roomName,
      },
      errors: {
        422: "Validation Error",
      },
    });
  }

  /**
   * Rooms List Meetings
   * @param data The data for the request.
   * @param data.roomName
   * @returns CalendarEventResponse Successful Response
   * @throws ApiError
   */
  public v1RoomsListMeetings(
    data: V1RoomsListMeetingsData,
  ): CancelablePromise<V1RoomsListMeetingsResponse> {
    return this.httpRequest.request({
      method: "GET",
      url: "/v1/rooms/{room_name}/meetings",
      path: {
        room_name: data.roomName,
      },
      errors: {
        422: "Validation Error",
      },
    });
  }

  /**
   * Rooms List Upcoming Meetings
   * @param data The data for the request.
   * @param data.roomName
   * @param data.minutesAhead
   * @returns CalendarEventResponse Successful Response
   * @throws ApiError
   */
  public v1RoomsListUpcomingMeetings(
    data: V1RoomsListUpcomingMeetingsData,
  ): CancelablePromise<V1RoomsListUpcomingMeetingsResponse> {
    return this.httpRequest.request({
      method: "GET",
      url: "/v1/rooms/{room_name}/meetings/upcoming",
      path: {
        room_name: data.roomName,
      },
      query: {
        minutes_ahead: data.minutesAhead,
      },
      errors: {
        422: "Validation Error",
      },
    });
  }

  /**
   * Rooms List Active Meetings
   * List all active meetings for a room (supports multiple active meetings)
   * @param data The data for the request.
   * @param data.roomName
   * @returns Meeting Successful Response
   * @throws ApiError
   */
  public v1RoomsListActiveMeetings(
    data: V1RoomsListActiveMeetingsData,
  ): CancelablePromise<V1RoomsListActiveMeetingsResponse> {
    return this.httpRequest.request({
      method: "GET",
      url: "/v1/rooms/{room_name}/meetings/active",
      path: {
        room_name: data.roomName,
      },
      errors: {
        422: "Validation Error",
      },
    });
  }

  /**
   * Rooms Join Meeting
   * Join a specific meeting by ID
   * @param data The data for the request.
   * @param data.roomName
   * @param data.meetingId
   * @returns Meeting Successful Response
   * @throws ApiError
   */
  public v1RoomsJoinMeeting(
    data: V1RoomsJoinMeetingData,
  ): CancelablePromise<V1RoomsJoinMeetingResponse> {
    return this.httpRequest.request({
      method: "POST",
      url: "/v1/rooms/{room_name}/meetings/{meeting_id}/join",
      path: {
        room_name: data.roomName,
        meeting_id: data.meetingId,
      },
      errors: {
        422: "Validation Error",
      },
    });
  }

  /**
   * Transcripts List
   * @param data The data for the request.
@@ -286,7 +438,6 @@ export class DefaultService {
   * @param data.limit Results per page
   * @param data.offset Number of results to skip
   * @param data.roomId
   * @param data.sourceKind
   * @returns SearchResponse Successful Response
   * @throws ApiError
   */
@@ -301,7 +452,6 @@ export class DefaultService {
        limit: data.limit,
        offset: data.offset,
        room_id: data.roomId,
        source_kind: data.sourceKind,
      },
      errors: {
        422: "Validation Error",
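For orientation, a sketch of driving the new ICS endpoints through the generated client. The DefaultService import path and the error handling are assumptions; only the method names and result fields come from this diff:

import type { DefaultService } from "../api"; // import path illustrative

// Illustrative: trigger an ICS sync for a room, then inspect its status.
// `api` would typically come from the app's useApi() hook.
async function refreshCalendar(api: DefaultService, roomName: string) {
  const result = await api.v1RoomsSyncIcs({ roomName });
  if (result.error) {
    console.error("ICS sync failed:", result.error);
    return;
  }
  console.log(
    `sync: +${result.events_created ?? 0} created, ` +
      `~${result.events_updated ?? 0} updated, -${result.events_deleted ?? 0} deleted`,
  );

  const status = await api.v1RoomsIcsStatus({ roomName });
  console.log(`tracking ${status.events_count ?? 0} events; next sync ${status.next_sync}`);
}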
@@ -9,6 +9,23 @@ export type Body_transcript_record_upload_v1_transcripts__transcript_id__record_
  chunk: Blob | File;
};

export type CalendarEventResponse = {
  id: string;
  room_id: string;
  ics_uid: string;
  title?: string | null;
  description?: string | null;
  start_time: string;
  end_time: string;
  attendees?: Array<{
    [key: string]: unknown;
  }> | null;
  location?: string | null;
  last_synced: string;
  created_at: string;
  updated_at: string;
};

export type CreateParticipant = {
  speaker?: number | null;
  name: string;
@@ -24,6 +41,9 @@ export type CreateRoom = {
  recording_type: string;
  recording_trigger: string;
  is_shared: boolean;
  ics_url?: string | null;
  ics_fetch_interval?: number;
  ics_enabled?: boolean;
};

export type CreateTranscript = {
@@ -123,6 +143,24 @@ export type HTTPValidationError = {
  detail?: Array<ValidationError>;
};

export type ICSStatus = {
  status: string;
  last_sync?: string | null;
  next_sync?: string | null;
  last_etag?: string | null;
  events_count?: number;
};

export type ICSSyncResult = {
  status: string;
  hash?: string | null;
  events_found?: number;
  events_created?: number;
  events_updated?: number;
  events_deleted?: number;
  error?: string | null;
};

export type Meeting = {
  id: string;
  room_name: string;
@@ -174,6 +212,11 @@ export type Room = {
  recording_type: string;
  recording_trigger: string;
  is_shared: boolean;
  ics_url?: string | null;
  ics_fetch_interval?: number;
  ics_enabled?: boolean;
  ics_last_sync?: string | null;
  ics_last_etag?: string | null;
};

export type RtcOffer = {
@@ -209,8 +252,6 @@ export type SearchResult = {
  title?: string | null;
  user_id?: string | null;
  room_id?: string | null;
  room_name?: string | null;
  source_kind: SourceKind;
  created_at: string;
  status: string;
  rank: number;
@@ -222,10 +263,6 @@ export type SearchResult = {
   * Text snippets around search matches
   */
  search_snippets: Array<string>;
  /**
   * Total number of matches found in the transcript
   */
  total_match_count?: number;
};

export type SourceKind = "room" | "live" | "file";
@@ -272,15 +309,18 @@ export type UpdateParticipant = {
};

export type UpdateRoom = {
  name: string;
  zulip_auto_post: boolean;
  zulip_stream: string;
  zulip_topic: string;
  is_locked: boolean;
  room_mode: string;
  recording_type: string;
  recording_trigger: string;
  is_shared: boolean;
  name?: string | null;
  zulip_auto_post?: boolean | null;
  zulip_stream?: string | null;
  zulip_topic?: string | null;
  is_locked?: boolean | null;
  room_mode?: string | null;
  recording_type?: string | null;
  recording_trigger?: string | null;
  is_shared?: boolean | null;
  ics_url?: string | null;
  ics_fetch_interval?: number | null;
  ics_enabled?: boolean | null;
};
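With every UpdateRoom field now optional and nullable, room updates gain PATCH-style semantics: callers send only the fields they want to change. A hypothetical call site enabling calendar sync; the feed URL and interval value are made up for illustration:

import type { UpdateRoom } from "../api"; // import path illustrative

// Illustrative: enable calendar sync on an existing room without
// resending the rest of its settings (all other fields stay undefined).
const patch: UpdateRoom = {
  ics_url: "https://calendar.example.com/team.ics", // hypothetical feed URL
  ics_enabled: true,
  ics_fetch_interval: 600, // seconds between fetches; schema default is 300
};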
export type UpdateTranscript = {
@@ -377,6 +417,44 @@ export type V1RoomsCreateMeetingData = {

export type V1RoomsCreateMeetingResponse = Meeting;

export type V1RoomsSyncIcsData = {
  roomName: string;
};

export type V1RoomsSyncIcsResponse = ICSSyncResult;

export type V1RoomsIcsStatusData = {
  roomName: string;
};

export type V1RoomsIcsStatusResponse = ICSStatus;

export type V1RoomsListMeetingsData = {
  roomName: string;
};

export type V1RoomsListMeetingsResponse = Array<CalendarEventResponse>;

export type V1RoomsListUpcomingMeetingsData = {
  minutesAhead?: number;
  roomName: string;
};

export type V1RoomsListUpcomingMeetingsResponse = Array<CalendarEventResponse>;

export type V1RoomsListActiveMeetingsData = {
  roomName: string;
};

export type V1RoomsListActiveMeetingsResponse = Array<Meeting>;

export type V1RoomsJoinMeetingData = {
  meetingId: string;
  roomName: string;
};

export type V1RoomsJoinMeetingResponse = Meeting;

export type V1TranscriptsListData = {
  /**
   * Page number
@@ -413,7 +491,6 @@ export type V1TranscriptsSearchData = {
   */
  q: string;
  roomId?: string | null;
  sourceKind?: SourceKind | null;
};

export type V1TranscriptsSearchResponse = SearchResponse;
@@ -677,6 +754,96 @@ export type $OpenApiTs = {
      };
    };
  };
  "/v1/rooms/{room_name}/ics/sync": {
    post: {
      req: V1RoomsSyncIcsData;
      res: {
        /**
         * Successful Response
         */
        200: ICSSyncResult;
        /**
         * Validation Error
         */
        422: HTTPValidationError;
      };
    };
  };
  "/v1/rooms/{room_name}/ics/status": {
    get: {
      req: V1RoomsIcsStatusData;
      res: {
        /**
         * Successful Response
         */
        200: ICSStatus;
        /**
         * Validation Error
         */
        422: HTTPValidationError;
      };
    };
  };
  "/v1/rooms/{room_name}/meetings": {
    get: {
      req: V1RoomsListMeetingsData;
      res: {
        /**
         * Successful Response
         */
        200: Array<CalendarEventResponse>;
        /**
         * Validation Error
         */
        422: HTTPValidationError;
      };
    };
  };
  "/v1/rooms/{room_name}/meetings/upcoming": {
    get: {
      req: V1RoomsListUpcomingMeetingsData;
      res: {
        /**
         * Successful Response
         */
        200: Array<CalendarEventResponse>;
        /**
         * Validation Error
         */
        422: HTTPValidationError;
      };
    };
  };
  "/v1/rooms/{room_name}/meetings/active": {
    get: {
      req: V1RoomsListActiveMeetingsData;
      res: {
        /**
         * Successful Response
         */
        200: Array<Meeting>;
        /**
         * Validation Error
         */
        422: HTTPValidationError;
      };
    };
  };
  "/v1/rooms/{room_name}/meetings/{meeting_id}/join": {
    post: {
      req: V1RoomsJoinMeetingData;
      res: {
        /**
         * Successful Response
         */
        200: Meeting;
        /**
         * Validation Error
         */
        422: HTTPValidationError;
      };
    };
  };
  "/v1/transcripts": {
    get: {
      req: V1TranscriptsListData;

@@ -1,2 +0,0 @@
// TODO better connection with generated schema; it's duplication
export const RECORD_A_MEETING_URL = "/transcripts/new" as const;
@@ -1,62 +0,0 @@
/**
 * Text highlighting and text fragment generation utilities
 * Used for search result highlighting and deep linking with Chrome Text Fragments
 */

import React from "react";

export interface HighlightResult {
  text: string;
  matches: string[];
}

/**
 * Escapes special regex characters in a string
 */
function escapeRegex(str: string): string {
  return str.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
}

export const highlightMatches = (
  text: string,
  query: string,
): { match: string; index: number }[] => {
  if (!query || !text) {
    return [];
  }

  const queryWords = query.trim().split(/\s+/);

  const regex = new RegExp(
    `(${queryWords.map((word) => escapeRegex(word)).join("|")})`,
    "gi",
  );

  return Array.from(text.matchAll(regex)).map((result) => ({
    match: result[0],
    index: result.index!,
  }));
};

export function findFirstHighlight(text: string, query: string): string | null {
  const matches = highlightMatches(text, query);
  if (matches.length === 0) {
    return null;
  }
  return matches[0].match;
}

export function generateTextFragment(
  text: string,
  query: string,
): {
  k: ":~:text";
  v: string;
} | null {
  const firstMatch = findFirstHighlight(text, query);
  if (!firstMatch) return null;
  return {
    k: ":~:text",
    v: firstMatch,
  };
}
@@ -136,10 +136,3 @@ export function extractDomain(url) {
    return null;
  }
}

export function assertExists<T>(value: T | null | undefined, err?: string): T {
  if (value === null || value === undefined) {
    throw new Error(`Assertion failed: ${err ?? "value is null or undefined"}`);
  }
  return value;
}

@@ -1,7 +1,6 @@
"use client";
import { redirect } from "next/navigation";
import { RECORD_A_MEETING_URL } from "./api/urls";

export default function Index() {
  redirect(RECORD_A_MEETING_URL);
  redirect("/transcripts/new");
}

@@ -5,17 +5,14 @@ import system from "./styles/theme";

import { WherebyProvider } from "@whereby.com/browser-sdk/react";
import { Toaster } from "./components/ui/toaster";
import { NuqsAdapter } from "nuqs/adapters/next/app";

export function Providers({ children }: { children: React.ReactNode }) {
  return (
    <NuqsAdapter>
      <ChakraProvider value={system}>
        <WherebyProvider>
          {children}
          <Toaster />
        </WherebyProvider>
      </ChakraProvider>
    </NuqsAdapter>
    <ChakraProvider value={system}>
      <WherebyProvider>
        {children}
        <Toaster />
      </WherebyProvider>
    </ChakraProvider>
  );
}

147
www/app/room/[roomName]/page.tsx
Normal file
@@ -0,0 +1,147 @@
"use client";

import { useEffect, useState } from "react";
import { Box, Spinner, VStack, Text } from "@chakra-ui/react";
import { useRouter } from "next/navigation";
import useApi from "../../lib/useApi";
import useSessionStatus from "../../lib/useSessionStatus";
import MeetingSelection from "../../[roomName]/MeetingSelection";
import { Meeting, Room } from "../../api";

interface RoomPageProps {
  params: {
    roomName: string;
  };
}

export default function RoomPage({ params }: RoomPageProps) {
  const { roomName } = params;
  const router = useRouter();
  const api = useApi();
  const { data: session } = useSessionStatus();

  const [room, setRoom] = useState<Room | null>(null);
  const [loading, setLoading] = useState(true);
  const [checkingMeetings, setCheckingMeetings] = useState(false);

  const isOwner = session?.user?.id === room?.user_id;

  useEffect(() => {
    if (!api) return;

    const fetchRoom = async () => {
      try {
        // Get room details
        const roomData = await api.v1RoomsRetrieve({ roomName });
        setRoom(roomData);

        // Check if we should show meeting selection
        if (roomData.ics_enabled) {
          setCheckingMeetings(true);

          // Check for active meetings
          const activeMeetings = await api.v1RoomsListActiveMeetings({
            roomName,
          });

          // Check for upcoming meetings
          const upcomingEvents = await api.v1RoomsListUpcomingMeetings({
            roomName,
            minutesAhead: 30,
          });

          // If there's only one active meeting and no upcoming, auto-join
          if (activeMeetings.length === 1 && upcomingEvents.length === 0) {
            handleMeetingSelect(activeMeetings[0]);
          } else if (
            activeMeetings.length === 0 &&
            upcomingEvents.length === 0
          ) {
            // No meetings, create unscheduled
            handleCreateUnscheduled();
          }
          // Otherwise, show selection UI (handled by render)
        } else {
          // ICS not enabled, use traditional flow
          handleCreateUnscheduled();
        }
      } catch (err) {
        console.error("Failed to fetch room:", err);
        // Room not found or error
        router.push("/rooms");
      } finally {
        setLoading(false);
        setCheckingMeetings(false);
      }
    };

    fetchRoom();
  }, [api, roomName]);

  const handleMeetingSelect = (meeting: Meeting) => {
    // Navigate to the classic room page with the meeting
    // Store meeting in session storage for the classic page to use
    sessionStorage.setItem(`meeting_${roomName}`, JSON.stringify(meeting));
    router.push(`/${roomName}`);
  };

  const handleCreateUnscheduled = async () => {
    if (!api) return;

    try {
      // Create a new unscheduled meeting
      const meeting = await api.v1RoomsCreateMeeting({ roomName });
      handleMeetingSelect(meeting);
    } catch (err) {
      console.error("Failed to create meeting:", err);
    }
  };

  if (loading || checkingMeetings) {
    return (
      <Box
        minH="100vh"
        display="flex"
        alignItems="center"
        justifyContent="center"
        bg="gray.50"
      >
        <VStack spacing={4}>
          <Spinner size="xl" color="blue.500" />
          <Text>{loading ? "Loading room..." : "Checking meetings..."}</Text>
        </VStack>
      </Box>
    );
  }

  if (!room) {
    return (
      <Box
        minH="100vh"
        display="flex"
        alignItems="center"
        justifyContent="center"
        bg="gray.50"
      >
        <Text fontSize="lg">Room not found</Text>
      </Box>
    );
  }

  // Show meeting selection if ICS is enabled and we have multiple options
  if (room.ics_enabled) {
    return (
      <Box minH="100vh" bg="gray.50">
        <MeetingSelection
          roomName={roomName}
          isOwner={isOwner}
          onMeetingSelect={handleMeetingSelect}
          onCreateUnscheduled={handleCreateUnscheduled}
        />
      </Box>
    );
  }

  // Should not reach here - redirected above
  return null;
}
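The ICS branch above reduces to a three-way decision on the two list lengths. A restatement of that routing logic as a pure function, purely for readability; it is not part of the diff:

// Illustrative restatement of the routing decision in fetchRoom above.
type Route = "auto-join" | "create-unscheduled" | "show-selection";

function decideRoute(activeCount: number, upcomingCount: number): Route {
  if (activeCount === 1 && upcomingCount === 0) return "auto-join";
  if (activeCount === 0 && upcomingCount === 0) return "create-unscheduled";
  return "show-selection"; // any other mix: let the user pick
}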
277
www/app/room/[roomName]/wait/page.tsx
Normal file
@@ -0,0 +1,277 @@
"use client";

import {
  Box,
  VStack,
  HStack,
  Text,
  Spinner,
  Progress,
  Card,
  CardBody,
  Button,
  Icon,
} from "@chakra-ui/react";
import { useEffect, useState } from "react";
import { useRouter, useSearchParams } from "next/navigation";
import { FaClock, FaArrowLeft } from "react-icons/fa";
import useApi from "../../../lib/useApi";
import { CalendarEventResponse } from "../../../api";

interface WaitingPageProps {
  params: {
    roomName: string;
  };
}

export default function WaitingPage({ params }: WaitingPageProps) {
  const { roomName } = params;
  const router = useRouter();
  const searchParams = useSearchParams();
  const eventId = searchParams.get("eventId");

  const [event, setEvent] = useState<CalendarEventResponse | null>(null);
  const [timeRemaining, setTimeRemaining] = useState<number>(0);
  const [loading, setLoading] = useState(true);
  const [checkingMeeting, setCheckingMeeting] = useState(false);
  const api = useApi();

  useEffect(() => {
    if (!api || !eventId) return;

    const fetchEvent = async () => {
      try {
        const events = await api.v1RoomsListUpcomingMeetings({
          roomName,
          minutesAhead: 60,
        });

        const targetEvent = events.find((e) => e.id === eventId);
        if (targetEvent) {
          setEvent(targetEvent);
        } else {
          // Event not found or already started
          router.push(`/room/${roomName}`);
        }
      } catch (err) {
        console.error("Failed to fetch event:", err);
        router.push(`/room/${roomName}`);
      } finally {
        setLoading(false);
      }
    };

    fetchEvent();
  }, [api, eventId, roomName]);

  useEffect(() => {
    if (!event) return;

    const updateCountdown = () => {
      const now = new Date();
      const start = new Date(event.start_time);
      const diff = Math.max(0, start.getTime() - now.getTime());

      setTimeRemaining(diff);

      // Check if meeting has started
      if (diff <= 0) {
        checkForActiveMeeting();
      }
    };

    const checkForActiveMeeting = async () => {
      if (!api || checkingMeeting) return;

      setCheckingMeeting(true);
      try {
        // Check for active meetings
        const activeMeetings = await api.v1RoomsListActiveMeetings({
          roomName,
        });

        // Find meeting for this calendar event
        const calendarMeeting = activeMeetings.find(
          (m) => m.calendar_event_id === eventId,
        );

        if (calendarMeeting) {
          // Meeting is now active, join it
          const meeting = await api.v1RoomsJoinMeeting({
            roomName,
            meetingId: calendarMeeting.id,
          });

          // Navigate to the meeting room
          router.push(`/${roomName}?meetingId=${meeting.id}`);
        }
      } catch (err) {
        console.error("Failed to check for active meeting:", err);
      } finally {
        setCheckingMeeting(false);
      }
    };

    // Update countdown every second
    const interval = setInterval(updateCountdown, 1000);

    // Check for meeting every 10 seconds when close to start time
    let checkInterval: NodeJS.Timeout | null = null;
    if (timeRemaining < 60000) {
      // Less than 1 minute
      checkInterval = setInterval(checkForActiveMeeting, 10000);
    }

    updateCountdown(); // Initial update

    return () => {
      clearInterval(interval);
      if (checkInterval) clearInterval(checkInterval);
    };
  }, [event, api, eventId, roomName, checkingMeeting]);

  const formatTime = (ms: number) => {
    const totalSeconds = Math.floor(ms / 1000);
    const hours = Math.floor(totalSeconds / 3600);
    const minutes = Math.floor((totalSeconds % 3600) / 60);
    const seconds = totalSeconds % 60;

    if (hours > 0) {
      return `${hours}:${minutes.toString().padStart(2, "0")}:${seconds
        .toString()
        .padStart(2, "0")}`;
    }
    return `${minutes}:${seconds.toString().padStart(2, "0")}`;
  };

  const getProgressValue = () => {
    if (!event) return 0;

    const now = new Date();
    const created = new Date(event.created_at);
    const start = new Date(event.start_time);
    const totalTime = start.getTime() - created.getTime();
    const elapsed = now.getTime() - created.getTime();

    return Math.min(100, (elapsed / totalTime) * 100);
  };

  if (loading) {
    return (
      <Box
        minH="100vh"
        display="flex"
        alignItems="center"
        justifyContent="center"
        bg="gray.50"
      >
        <VStack spacing={4}>
          <Spinner size="xl" color="blue.500" />
          <Text>Loading meeting details...</Text>
        </VStack>
      </Box>
    );
  }

  if (!event) {
    return (
      <Box
        minH="100vh"
        display="flex"
        alignItems="center"
        justifyContent="center"
        bg="gray.50"
      >
        <VStack spacing={4}>
          <Text fontSize="lg">Meeting not found</Text>
          <Button
            leftIcon={<FaArrowLeft />}
            onClick={() => router.push(`/room/${roomName}`)}
          >
            Back to Room
          </Button>
        </VStack>
      </Box>
    );
  }

  return (
    <Box
      minH="100vh"
      display="flex"
      alignItems="center"
      justifyContent="center"
      bg="gray.50"
    >
      <Card maxW="lg" width="100%" mx={4}>
        <CardBody>
          <VStack spacing={6} py={4}>
            <Icon as={FaClock} boxSize={16} color="blue.500" />

            <VStack spacing={2}>
              <Text fontSize="2xl" fontWeight="bold">
                {event.title || "Scheduled Meeting"}
              </Text>
              <Text color="gray.600" textAlign="center">
                The meeting will start automatically when it's time
              </Text>
            </VStack>

            <Box width="100%">
              <Text
                fontSize="4xl"
                fontWeight="bold"
                textAlign="center"
                color="blue.600"
              >
                {formatTime(timeRemaining)}
              </Text>
              <Progress
                value={getProgressValue()}
                colorScheme="blue"
                size="sm"
                mt={4}
                borderRadius="full"
              />
            </Box>

            {event.description && (
              <Box width="100%" p={4} bg="gray.100" borderRadius="md">
                <Text fontSize="sm" fontWeight="semibold" mb={1}>
                  Meeting Description
                </Text>
                <Text fontSize="sm" color="gray.700">
                  {event.description}
                </Text>
              </Box>
            )}

            <VStack spacing={3} width="100%">
              <Text fontSize="sm" color="gray.500">
                Scheduled for {new Date(event.start_time).toLocaleString()}
              </Text>

              {checkingMeeting && (
                <HStack spacing={2}>
                  <Spinner size="sm" color="blue.500" />
                  <Text fontSize="sm" color="blue.600">
                    Checking if meeting has started...
                  </Text>
                </HStack>
              )}
            </VStack>

            <Button
              variant="outline"
              leftIcon={<FaArrowLeft />}
              onClick={() => router.push(`/room/${roomName}`)}
              width="100%"
            >
              Back to Meeting Selection
            </Button>
          </VStack>
        </CardBody>
      </Card>
    </Box>
  );
}
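As a spot check on the countdown formatting above, a standalone copy of the same logic with two worked values; illustrative only:

// Standalone copy of the formatTime logic above, for verification.
const fmt = (ms: number) => {
  const total = Math.floor(ms / 1000);
  const h = Math.floor(total / 3600);
  const m = Math.floor((total % 3600) / 60);
  const s = total % 60;
  const mm = m.toString().padStart(2, "0");
  const ss = s.toString().padStart(2, "0");
  return h > 0 ? `${h}:${mm}:${ss}` : `${m}:${ss}`; // minutes unpadded below an hour
};

console.assert(fmt(125_000) === "2:05");      // 125 s -> 2 min 5 s
console.assert(fmt(3_725_000) === "1:02:05"); // 3725 s -> 1 h 2 min 5 s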
@@ -31,7 +31,6 @@
    "next": "^14.2.30",
    "next-auth": "^4.24.7",
    "next-themes": "^0.4.6",
    "nuqs": "^2.4.3",
    "postcss": "8.4.31",
    "prop-types": "^15.8.1",
    "react": "^18.2.0",

39
www/pnpm-lock.yaml
generated
@@ -67,9 +67,6 @@ importers:
      next-themes:
        specifier: ^0.4.6
        version: 0.4.6(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
      nuqs:
        specifier: ^2.4.3
        version: 2.4.3(next@14.2.31(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(sass@1.90.0))(react@18.3.1)
      postcss:
        specifier: 8.4.31
        version: 8.4.31
@@ -5439,12 +5436,6 @@ packages:
    }
    engines: { node: ">= 8" }

  mitt@3.0.1:
    resolution:
      {
        integrity: sha512-vKivATfr97l2/QBCYAkXYDbrIWPM2IIKEl7YPhjCvKlG3kE2gm+uBo6nEXK3M5/Ffh/FLpKExzOQ3JJoJGFKBw==,
      }

  mkdirp@0.5.6:
    resolution:
      {
@@ -5669,27 +5660,6 @@ packages:
    }
    deprecated: This package is no longer supported.

  nuqs@2.4.3:
    resolution:
      {
        integrity: sha512-BgtlYpvRwLYiJuWzxt34q2bXu/AIS66sLU1QePIMr2LWkb+XH0vKXdbLSgn9t6p7QKzwI7f38rX3Wl9llTXQ8Q==,
      }
    peerDependencies:
      "@remix-run/react": ">=2"
      next: ">=14.2.0"
      react: ">=18.2.0 || ^19.0.0-0"
      react-router: ^6 || ^7
      react-router-dom: ^6 || ^7
    peerDependenciesMeta:
      "@remix-run/react":
        optional: true
      next:
        optional: true
      react-router:
        optional: true
      react-router-dom:
        optional: true

  nypm@0.5.4:
    resolution:
      {
@@ -11583,8 +11553,6 @@ snapshots:
      minipass: 3.3.6
      yallist: 4.0.0

  mitt@3.0.1: {}

  mkdirp@0.5.6:
    dependencies:
      minimist: 1.2.8
@@ -11706,13 +11674,6 @@ snapshots:
      gauge: 3.0.2
      set-blocking: 2.0.0

  nuqs@2.4.3(next@14.2.31(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(sass@1.90.0))(react@18.3.1):
    dependencies:
      mitt: 3.0.1
      react: 18.3.1
    optionalDependencies:
      next: 14.2.31(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(sass@1.90.0)

  nypm@0.5.4:
    dependencies:
      citty: 0.1.6