mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2026-02-05 18:36:45 +00:00
Compare commits
11 Commits
v0.28.1
...
feat/trans
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b84fd1fc24 | ||
|
|
3652de9fca | ||
|
|
68df825734 | ||
|
|
8ca5324c1a | ||
|
|
39e0b89e67 | ||
|
|
544793a24f | ||
|
|
088451645a | ||
|
|
2dfe82afbc | ||
|
|
b461ebb488 | ||
|
|
0b5112cabc | ||
|
|
316f7b316d |
14
CHANGELOG.md
14
CHANGELOG.md
@@ -1,19 +1,5 @@
|
|||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
## [0.28.1](https://github.com/Monadical-SAS/reflector/compare/v0.28.0...v0.28.1) (2026-01-21)
|
|
||||||
|
|
||||||
|
|
||||||
### Bug Fixes
|
|
||||||
|
|
||||||
* ics non-sync bugfix ([#823](https://github.com/Monadical-SAS/reflector/issues/823)) ([23d2bc2](https://github.com/Monadical-SAS/reflector/commit/23d2bc283d4d02187b250d2055103e0374ee93d6))
|
|
||||||
|
|
||||||
## [0.28.0](https://github.com/Monadical-SAS/reflector/compare/v0.27.0...v0.28.0) (2026-01-20)
|
|
||||||
|
|
||||||
|
|
||||||
### Features
|
|
||||||
|
|
||||||
* worker affinity ([#819](https://github.com/Monadical-SAS/reflector/issues/819)) ([3b6540e](https://github.com/Monadical-SAS/reflector/commit/3b6540eae5b597449f98661bdf15483b77be3268))
|
|
||||||
|
|
||||||
## [0.27.0](https://github.com/Monadical-SAS/reflector/compare/v0.26.0...v0.27.0) (2025-12-26)
|
## [0.27.0](https://github.com/Monadical-SAS/reflector/compare/v0.26.0...v0.27.0) (2025-12-26)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
ENTRYPOINT: beat
|
ENTRYPOINT: beat
|
||||||
|
|
||||||
hatchet-worker-cpu:
|
hatchet-worker:
|
||||||
build:
|
build:
|
||||||
context: server
|
context: server
|
||||||
volumes:
|
volumes:
|
||||||
@@ -43,20 +43,7 @@ services:
|
|||||||
env_file:
|
env_file:
|
||||||
- ./server/.env
|
- ./server/.env
|
||||||
environment:
|
environment:
|
||||||
ENTRYPOINT: hatchet-worker-cpu
|
ENTRYPOINT: hatchet-worker
|
||||||
depends_on:
|
|
||||||
hatchet:
|
|
||||||
condition: service_healthy
|
|
||||||
hatchet-worker-llm:
|
|
||||||
build:
|
|
||||||
context: server
|
|
||||||
volumes:
|
|
||||||
- ./server/:/app/
|
|
||||||
- /app/.venv
|
|
||||||
env_file:
|
|
||||||
- ./server/.env
|
|
||||||
environment:
|
|
||||||
ENTRYPOINT: hatchet-worker-llm
|
|
||||||
depends_on:
|
depends_on:
|
||||||
hatchet:
|
hatchet:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from reflector.views.rooms import router as rooms_router
|
|||||||
from reflector.views.rtc_offer import router as rtc_offer_router
|
from reflector.views.rtc_offer import router as rtc_offer_router
|
||||||
from reflector.views.transcripts import router as transcripts_router
|
from reflector.views.transcripts import router as transcripts_router
|
||||||
from reflector.views.transcripts_audio import router as transcripts_audio_router
|
from reflector.views.transcripts_audio import router as transcripts_audio_router
|
||||||
|
from reflector.views.transcripts_chat import router as transcripts_chat_router
|
||||||
from reflector.views.transcripts_participants import (
|
from reflector.views.transcripts_participants import (
|
||||||
router as transcripts_participants_router,
|
router as transcripts_participants_router,
|
||||||
)
|
)
|
||||||
@@ -90,6 +91,7 @@ app.include_router(transcripts_participants_router, prefix="/v1")
|
|||||||
app.include_router(transcripts_speaker_router, prefix="/v1")
|
app.include_router(transcripts_speaker_router, prefix="/v1")
|
||||||
app.include_router(transcripts_upload_router, prefix="/v1")
|
app.include_router(transcripts_upload_router, prefix="/v1")
|
||||||
app.include_router(transcripts_websocket_router, prefix="/v1")
|
app.include_router(transcripts_websocket_router, prefix="/v1")
|
||||||
|
app.include_router(transcripts_chat_router, prefix="/v1")
|
||||||
app.include_router(transcripts_webrtc_router, prefix="/v1")
|
app.include_router(transcripts_webrtc_router, prefix="/v1")
|
||||||
app.include_router(transcripts_process_router, prefix="/v1")
|
app.include_router(transcripts_process_router, prefix="/v1")
|
||||||
app.include_router(user_router, prefix="/v1")
|
app.include_router(user_router, prefix="/v1")
|
||||||
|
|||||||
77
server/reflector/hatchet/run_workers.py
Normal file
77
server/reflector/hatchet/run_workers.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""
|
||||||
|
Run Hatchet workers for the multitrack pipeline.
|
||||||
|
Runs as a separate process, just like Celery workers.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
uv run -m reflector.hatchet.run_workers
|
||||||
|
|
||||||
|
# Or via docker:
|
||||||
|
docker compose exec server uv run -m reflector.hatchet.run_workers
|
||||||
|
"""
|
||||||
|
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from hatchet_sdk.rate_limit import RateLimitDuration
|
||||||
|
|
||||||
|
from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, LLM_RATE_LIMIT_PER_SECOND
|
||||||
|
from reflector.logger import logger
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""Start Hatchet worker polling."""
|
||||||
|
if not settings.HATCHET_ENABLED:
|
||||||
|
logger.error("HATCHET_ENABLED is False, not starting workers")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if not settings.HATCHET_CLIENT_TOKEN:
|
||||||
|
logger.error("HATCHET_CLIENT_TOKEN is not set")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Starting Hatchet workers",
|
||||||
|
debug=settings.HATCHET_DEBUG,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import here (not top-level) - workflow modules call HatchetClientManager.get_client()
|
||||||
|
# at module level because Hatchet SDK decorators (@workflow.task) bind at import time.
|
||||||
|
# Can't use lazy init: decorators need the client object when function is defined.
|
||||||
|
from reflector.hatchet.client import HatchetClientManager # noqa: PLC0415
|
||||||
|
from reflector.hatchet.workflows import ( # noqa: PLC0415
|
||||||
|
daily_multitrack_pipeline,
|
||||||
|
subject_workflow,
|
||||||
|
topic_chunk_workflow,
|
||||||
|
track_workflow,
|
||||||
|
)
|
||||||
|
|
||||||
|
hatchet = HatchetClientManager.get_client()
|
||||||
|
|
||||||
|
hatchet.rate_limits.put(
|
||||||
|
LLM_RATE_LIMIT_KEY, LLM_RATE_LIMIT_PER_SECOND, RateLimitDuration.SECOND
|
||||||
|
)
|
||||||
|
|
||||||
|
worker = hatchet.worker(
|
||||||
|
"reflector-pipeline-worker",
|
||||||
|
workflows=[
|
||||||
|
daily_multitrack_pipeline,
|
||||||
|
subject_workflow,
|
||||||
|
topic_chunk_workflow,
|
||||||
|
track_workflow,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def shutdown_handler(signum: int, frame) -> None:
|
||||||
|
logger.info("Received shutdown signal, stopping workers...")
|
||||||
|
# Worker cleanup happens automatically on exit
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, shutdown_handler)
|
||||||
|
signal.signal(signal.SIGTERM, shutdown_handler)
|
||||||
|
|
||||||
|
logger.info("Starting Hatchet worker polling...")
|
||||||
|
worker.start()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
"""
|
|
||||||
CPU-heavy worker pool for audio processing tasks.
|
|
||||||
Handles ONLY: mixdown_tracks
|
|
||||||
|
|
||||||
Configuration:
|
|
||||||
- slots=1: Only mixdown (already serialized globally with max_runs=1)
|
|
||||||
- Worker affinity: pool=cpu-heavy
|
|
||||||
"""
|
|
||||||
|
|
||||||
from reflector.hatchet.client import HatchetClientManager
|
|
||||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
|
||||||
daily_multitrack_pipeline,
|
|
||||||
)
|
|
||||||
from reflector.logger import logger
|
|
||||||
from reflector.settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
if not settings.HATCHET_ENABLED:
|
|
||||||
logger.error("HATCHET_ENABLED is False, not starting CPU workers")
|
|
||||||
return
|
|
||||||
|
|
||||||
hatchet = HatchetClientManager.get_client()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Starting Hatchet CPU worker pool (mixdown only)",
|
|
||||||
worker_name="cpu-worker-pool",
|
|
||||||
slots=1,
|
|
||||||
labels={"pool": "cpu-heavy"},
|
|
||||||
)
|
|
||||||
|
|
||||||
cpu_worker = hatchet.worker(
|
|
||||||
"cpu-worker-pool",
|
|
||||||
slots=1, # Only 1 mixdown at a time (already serialized globally)
|
|
||||||
labels={
|
|
||||||
"pool": "cpu-heavy",
|
|
||||||
},
|
|
||||||
workflows=[daily_multitrack_pipeline],
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
cpu_worker.start()
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
logger.info("Received shutdown signal, stopping CPU workers...")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
"""
|
|
||||||
LLM/I/O worker pool for all non-CPU tasks.
|
|
||||||
Handles: all tasks except mixdown_tracks (transcription, LLM inference, orchestration)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from reflector.hatchet.client import HatchetClientManager
|
|
||||||
from reflector.hatchet.workflows.daily_multitrack_pipeline import (
|
|
||||||
daily_multitrack_pipeline,
|
|
||||||
)
|
|
||||||
from reflector.hatchet.workflows.subject_processing import subject_workflow
|
|
||||||
from reflector.hatchet.workflows.topic_chunk_processing import topic_chunk_workflow
|
|
||||||
from reflector.hatchet.workflows.track_processing import track_workflow
|
|
||||||
from reflector.logger import logger
|
|
||||||
from reflector.settings import settings
|
|
||||||
|
|
||||||
SLOTS = 10
|
|
||||||
WORKER_NAME = "llm-worker-pool"
|
|
||||||
POOL = "llm-io"
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
if not settings.HATCHET_ENABLED:
|
|
||||||
logger.error("HATCHET_ENABLED is False, not starting LLM workers")
|
|
||||||
return
|
|
||||||
|
|
||||||
hatchet = HatchetClientManager.get_client()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Starting Hatchet LLM worker pool (all tasks except mixdown)",
|
|
||||||
worker_name=WORKER_NAME,
|
|
||||||
slots=SLOTS,
|
|
||||||
labels={"pool": POOL},
|
|
||||||
)
|
|
||||||
|
|
||||||
llm_worker = hatchet.worker(
|
|
||||||
WORKER_NAME,
|
|
||||||
slots=SLOTS, # not all slots are probably used
|
|
||||||
labels={
|
|
||||||
"pool": POOL,
|
|
||||||
},
|
|
||||||
workflows=[
|
|
||||||
daily_multitrack_pipeline,
|
|
||||||
topic_chunk_workflow,
|
|
||||||
subject_workflow,
|
|
||||||
track_workflow,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
llm_worker.start()
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
logger.info("Received shutdown signal, stopping LLM workers...")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -23,12 +23,7 @@ from pathlib import Path
|
|||||||
from typing import Any, Callable, Coroutine, Protocol, TypeVar
|
from typing import Any, Callable, Coroutine, Protocol, TypeVar
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from hatchet_sdk import (
|
from hatchet_sdk import Context
|
||||||
ConcurrencyExpression,
|
|
||||||
ConcurrencyLimitStrategy,
|
|
||||||
Context,
|
|
||||||
)
|
|
||||||
from hatchet_sdk.labels import DesiredWorkerLabel
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from reflector.dailyco_api.client import DailyApiClient
|
from reflector.dailyco_api.client import DailyApiClient
|
||||||
@@ -472,20 +467,6 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
|
|||||||
parents=[process_tracks],
|
parents=[process_tracks],
|
||||||
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
|
execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
|
||||||
retries=3,
|
retries=3,
|
||||||
desired_worker_labels={
|
|
||||||
"pool": DesiredWorkerLabel(
|
|
||||||
value="cpu-heavy",
|
|
||||||
required=True,
|
|
||||||
weight=100,
|
|
||||||
),
|
|
||||||
},
|
|
||||||
concurrency=[
|
|
||||||
ConcurrencyExpression(
|
|
||||||
expression="'mixdown-global'",
|
|
||||||
max_runs=1, # serialize mixdown to prevent resource contention
|
|
||||||
limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, # Queue
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
@with_error_handling(TaskName.MIXDOWN_TRACKS)
|
@with_error_handling(TaskName.MIXDOWN_TRACKS)
|
||||||
async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
|
async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
|
||||||
|
|||||||
@@ -7,11 +7,7 @@ Spawned dynamically by detect_topics via aio_run_many() for parallel processing.
|
|||||||
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
from hatchet_sdk import (
|
from hatchet_sdk import ConcurrencyExpression, ConcurrencyLimitStrategy, Context
|
||||||
ConcurrencyExpression,
|
|
||||||
ConcurrencyLimitStrategy,
|
|
||||||
Context,
|
|
||||||
)
|
|
||||||
from hatchet_sdk.rate_limit import RateLimit
|
from hatchet_sdk.rate_limit import RateLimit
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -38,13 +34,11 @@ hatchet = HatchetClientManager.get_client()
|
|||||||
topic_chunk_workflow = hatchet.workflow(
|
topic_chunk_workflow = hatchet.workflow(
|
||||||
name="TopicChunkProcessing",
|
name="TopicChunkProcessing",
|
||||||
input_validator=TopicChunkInput,
|
input_validator=TopicChunkInput,
|
||||||
concurrency=[
|
concurrency=ConcurrencyExpression(
|
||||||
ConcurrencyExpression(
|
expression="'global'", # constant string = global limit across all runs
|
||||||
expression="'global'", # constant string = global limit across all runs
|
max_runs=20,
|
||||||
max_runs=20,
|
limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
|
||||||
limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
|
),
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -319,6 +319,21 @@ class ICSSyncService:
|
|||||||
calendar = self.fetch_service.parse_ics(ics_content)
|
calendar = self.fetch_service.parse_ics(ics_content)
|
||||||
|
|
||||||
content_hash = hashlib.md5(ics_content.encode()).hexdigest()
|
content_hash = hashlib.md5(ics_content.encode()).hexdigest()
|
||||||
|
if room.ics_last_etag == content_hash:
|
||||||
|
logger.info("No changes in ICS for room", room_id=room.id)
|
||||||
|
room_url = f"{settings.UI_BASE_URL}/{room.name}"
|
||||||
|
events, total_events = self.fetch_service.extract_room_events(
|
||||||
|
calendar, room.name, room_url
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": SyncStatus.UNCHANGED,
|
||||||
|
"hash": content_hash,
|
||||||
|
"events_found": len(events),
|
||||||
|
"total_events": total_events,
|
||||||
|
"events_created": 0,
|
||||||
|
"events_updated": 0,
|
||||||
|
"events_deleted": 0,
|
||||||
|
}
|
||||||
|
|
||||||
# Extract matching events
|
# Extract matching events
|
||||||
room_url = f"{settings.UI_BASE_URL}/{room.name}"
|
room_url = f"{settings.UI_BASE_URL}/{room.name}"
|
||||||
@@ -356,44 +371,6 @@ class ICSSyncService:
|
|||||||
time_since_sync = datetime.now(timezone.utc) - room.ics_last_sync
|
time_since_sync = datetime.now(timezone.utc) - room.ics_last_sync
|
||||||
return time_since_sync.total_seconds() >= room.ics_fetch_interval
|
return time_since_sync.total_seconds() >= room.ics_fetch_interval
|
||||||
|
|
||||||
def _event_data_changed(self, existing: CalendarEvent, new_data: EventData) -> bool:
|
|
||||||
"""Check if event data has changed by comparing relevant fields.
|
|
||||||
|
|
||||||
IMPORTANT: When adding fields to CalendarEvent/EventData, update this method
|
|
||||||
and the _COMPARED_FIELDS set below for runtime validation.
|
|
||||||
"""
|
|
||||||
# Fields that come from ICS and should trigger updates when changed
|
|
||||||
_COMPARED_FIELDS = {
|
|
||||||
"title",
|
|
||||||
"description",
|
|
||||||
"start_time",
|
|
||||||
"end_time",
|
|
||||||
"location",
|
|
||||||
"attendees",
|
|
||||||
"ics_raw_data",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Runtime exhaustiveness check: ensure we're comparing all EventData fields
|
|
||||||
event_data_fields = set(EventData.__annotations__.keys()) - {"ics_uid"}
|
|
||||||
if event_data_fields != _COMPARED_FIELDS:
|
|
||||||
missing = event_data_fields - _COMPARED_FIELDS
|
|
||||||
extra = _COMPARED_FIELDS - event_data_fields
|
|
||||||
raise RuntimeError(
|
|
||||||
f"_event_data_changed() field mismatch: "
|
|
||||||
f"missing={missing}, extra={extra}. "
|
|
||||||
f"Update the comparison logic when adding/removing fields."
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
|
||||||
existing.title != new_data["title"]
|
|
||||||
or existing.description != new_data["description"]
|
|
||||||
or existing.start_time != new_data["start_time"]
|
|
||||||
or existing.end_time != new_data["end_time"]
|
|
||||||
or existing.location != new_data["location"]
|
|
||||||
or existing.attendees != new_data["attendees"]
|
|
||||||
or existing.ics_raw_data != new_data["ics_raw_data"]
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _sync_events_to_database(
|
async def _sync_events_to_database(
|
||||||
self, room_id: str, events: list[EventData]
|
self, room_id: str, events: list[EventData]
|
||||||
) -> SyncStats:
|
) -> SyncStats:
|
||||||
@@ -409,14 +386,11 @@ class ICSSyncService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
# Only count as updated if data actually changed
|
updated += 1
|
||||||
if self._event_data_changed(existing, event_data):
|
|
||||||
updated += 1
|
|
||||||
await calendar_events_controller.upsert(calendar_event)
|
|
||||||
else:
|
else:
|
||||||
created += 1
|
created += 1
|
||||||
await calendar_events_controller.upsert(calendar_event)
|
|
||||||
|
|
||||||
|
await calendar_events_controller.upsert(calendar_event)
|
||||||
current_ics_uids.append(event_data["ics_uid"])
|
current_ics_uids.append(event_data["ics_uid"])
|
||||||
|
|
||||||
# Soft delete events that are no longer in calendar
|
# Soft delete events that are no longer in calendar
|
||||||
|
|||||||
133
server/reflector/views/transcripts_chat.py
Normal file
133
server/reflector/views/transcripts_chat.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
"""
|
||||||
|
Transcripts chat API
|
||||||
|
====================
|
||||||
|
|
||||||
|
WebSocket endpoint for bidirectional chat with LLM about transcript content.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||||
|
from llama_index.core import Settings
|
||||||
|
from llama_index.core.base.llms.types import ChatMessage, MessageRole
|
||||||
|
|
||||||
|
from reflector.auth.auth_jwt import JWTAuth
|
||||||
|
from reflector.db.recordings import recordings_controller
|
||||||
|
from reflector.db.transcripts import transcripts_controller
|
||||||
|
from reflector.db.users import user_controller
|
||||||
|
from reflector.llm import LLM
|
||||||
|
from reflector.settings import settings
|
||||||
|
from reflector.utils.transcript_formats import topics_to_webvtt_named
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_is_multitrack(transcript) -> bool:
|
||||||
|
"""Detect if transcript is from multitrack recording."""
|
||||||
|
if not transcript.recording_id:
|
||||||
|
return False
|
||||||
|
recording = await recordings_controller.get_by_id(transcript.recording_id)
|
||||||
|
return recording is not None and recording.is_multitrack
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/transcripts/{transcript_id}/chat")
|
||||||
|
async def transcript_chat_websocket(
|
||||||
|
transcript_id: str,
|
||||||
|
websocket: WebSocket,
|
||||||
|
):
|
||||||
|
"""WebSocket endpoint for chatting with LLM about transcript content."""
|
||||||
|
# 1. Auth check (optional) - extract token from WebSocket subprotocol header
|
||||||
|
# Browser can't send Authorization header for WS; use subprotocol: ["bearer", token]
|
||||||
|
raw_subprotocol = websocket.headers.get("sec-websocket-protocol") or ""
|
||||||
|
parts = [p.strip() for p in raw_subprotocol.split(",") if p.strip()]
|
||||||
|
token: Optional[str] = None
|
||||||
|
negotiated_subprotocol: Optional[str] = None
|
||||||
|
if len(parts) >= 2 and parts[0].lower() == "bearer":
|
||||||
|
negotiated_subprotocol = "bearer"
|
||||||
|
token = parts[1]
|
||||||
|
|
||||||
|
user_id: Optional[str] = None
|
||||||
|
if token:
|
||||||
|
try:
|
||||||
|
payload = JWTAuth().verify_token(token)
|
||||||
|
authentik_uid = payload.get("sub")
|
||||||
|
|
||||||
|
if authentik_uid:
|
||||||
|
user = await user_controller.get_by_authentik_uid(authentik_uid)
|
||||||
|
if user:
|
||||||
|
user_id = user.id
|
||||||
|
except Exception:
|
||||||
|
# Auth failed - continue as anonymous
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Get transcript (respects user_id for private transcripts)
|
||||||
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
|
transcript_id, user_id=user_id
|
||||||
|
)
|
||||||
|
if not transcript:
|
||||||
|
await websocket.close(code=1008) # Policy violation (not found/unauthorized)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. Accept connection (with negotiated subprotocol if present)
|
||||||
|
await websocket.accept(subprotocol=negotiated_subprotocol)
|
||||||
|
|
||||||
|
# 3. Generate WebVTT context
|
||||||
|
is_multitrack = await _get_is_multitrack(transcript)
|
||||||
|
webvtt = topics_to_webvtt_named(
|
||||||
|
transcript.topics, transcript.participants, is_multitrack
|
||||||
|
)
|
||||||
|
|
||||||
|
# Truncate if needed (15k char limit for POC)
|
||||||
|
webvtt_truncated = webvtt[:15000] if len(webvtt) > 15000 else webvtt
|
||||||
|
|
||||||
|
# 4. Configure LLM
|
||||||
|
llm = LLM(settings=settings, temperature=0.7)
|
||||||
|
|
||||||
|
# 5. System message with transcript context
|
||||||
|
system_msg = f"""You are analyzing this meeting transcript (WebVTT):
|
||||||
|
|
||||||
|
{webvtt_truncated}
|
||||||
|
|
||||||
|
Answer questions about content, speakers, timeline. Include timestamps when relevant."""
|
||||||
|
|
||||||
|
# 6. Conversation history
|
||||||
|
conversation_history = [ChatMessage(role=MessageRole.SYSTEM, content=system_msg)]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 7. Message loop
|
||||||
|
while True:
|
||||||
|
data = await websocket.receive_json()
|
||||||
|
|
||||||
|
if data.get("type") == "get_context":
|
||||||
|
# Return WebVTT context (for debugging/testing)
|
||||||
|
await websocket.send_json({"type": "context", "webvtt": webvtt})
|
||||||
|
continue
|
||||||
|
|
||||||
|
if data.get("type") != "message":
|
||||||
|
# Echo unknown types for backward compatibility
|
||||||
|
await websocket.send_json({"type": "echo", "data": data})
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Add user message to history
|
||||||
|
user_msg = ChatMessage(role=MessageRole.USER, content=data.get("text", ""))
|
||||||
|
conversation_history.append(user_msg)
|
||||||
|
|
||||||
|
# Stream LLM response
|
||||||
|
assistant_msg = ""
|
||||||
|
chat_stream = await Settings.llm.astream_chat(conversation_history)
|
||||||
|
async for chunk in chat_stream:
|
||||||
|
token = chunk.delta or ""
|
||||||
|
if token:
|
||||||
|
await websocket.send_json({"type": "token", "text": token})
|
||||||
|
assistant_msg += token
|
||||||
|
|
||||||
|
# Save assistant response to history
|
||||||
|
conversation_history.append(
|
||||||
|
ChatMessage(role=MessageRole.ASSISTANT, content=assistant_msg)
|
||||||
|
)
|
||||||
|
await websocket.send_json({"type": "done"})
|
||||||
|
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
await websocket.send_json({"type": "error", "message": str(e)})
|
||||||
@@ -7,10 +7,8 @@ elif [ "${ENTRYPOINT}" = "worker" ]; then
|
|||||||
uv run celery -A reflector.worker.app worker --loglevel=info
|
uv run celery -A reflector.worker.app worker --loglevel=info
|
||||||
elif [ "${ENTRYPOINT}" = "beat" ]; then
|
elif [ "${ENTRYPOINT}" = "beat" ]; then
|
||||||
uv run celery -A reflector.worker.app beat --loglevel=info
|
uv run celery -A reflector.worker.app beat --loglevel=info
|
||||||
elif [ "${ENTRYPOINT}" = "hatchet-worker-cpu" ]; then
|
elif [ "${ENTRYPOINT}" = "hatchet-worker" ]; then
|
||||||
uv run python -m reflector.hatchet.run_workers_cpu
|
uv run python -m reflector.hatchet.run_workers
|
||||||
elif [ "${ENTRYPOINT}" = "hatchet-worker-llm" ]; then
|
|
||||||
uv run python -m reflector.hatchet.run_workers_llm
|
|
||||||
else
|
else
|
||||||
echo "Unknown command"
|
echo "Unknown command"
|
||||||
fi
|
fi
|
||||||
|
|||||||
@@ -189,17 +189,14 @@ async def test_ics_sync_service_sync_room_calendar():
|
|||||||
assert events[0].ics_uid == "sync-event-1"
|
assert events[0].ics_uid == "sync-event-1"
|
||||||
assert events[0].title == "Sync Test Meeting"
|
assert events[0].title == "Sync Test Meeting"
|
||||||
|
|
||||||
# Second sync with same content (calendar unchanged, but sync always runs)
|
# Second sync with same content (should be unchanged)
|
||||||
# Refresh room to get updated etag and force sync by setting old sync time
|
# Refresh room to get updated etag and force sync by setting old sync time
|
||||||
room = await rooms_controller.get_by_id(room.id)
|
room = await rooms_controller.get_by_id(room.id)
|
||||||
await rooms_controller.update(
|
await rooms_controller.update(
|
||||||
room, {"ics_last_sync": datetime.now(timezone.utc) - timedelta(minutes=10)}
|
room, {"ics_last_sync": datetime.now(timezone.utc) - timedelta(minutes=10)}
|
||||||
)
|
)
|
||||||
result = await sync_service.sync_room_calendar(room)
|
result = await sync_service.sync_room_calendar(room)
|
||||||
assert result["status"] == "success"
|
assert result["status"] == "unchanged"
|
||||||
assert result["events_created"] == 0
|
|
||||||
assert result["events_updated"] == 0
|
|
||||||
assert result["events_deleted"] == 0
|
|
||||||
|
|
||||||
# Third sync with updated event
|
# Third sync with updated event
|
||||||
event["summary"] = "Updated Meeting Title"
|
event["summary"] = "Updated Meeting Title"
|
||||||
@@ -291,43 +288,3 @@ async def test_ics_sync_service_error_handling():
|
|||||||
result = await sync_service.sync_room_calendar(room)
|
result = await sync_service.sync_room_calendar(room)
|
||||||
assert result["status"] == "error"
|
assert result["status"] == "error"
|
||||||
assert "Network error" in result["error"]
|
assert "Network error" in result["error"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_event_data_changed_exhaustiveness():
|
|
||||||
"""Test that _event_data_changed compares all EventData fields (except ics_uid).
|
|
||||||
|
|
||||||
This test ensures programmers don't forget to update the comparison logic
|
|
||||||
when adding new fields to EventData/CalendarEvent.
|
|
||||||
"""
|
|
||||||
from reflector.services.ics_sync import EventData
|
|
||||||
|
|
||||||
sync_service = ICSSyncService()
|
|
||||||
|
|
||||||
from reflector.db.calendar_events import CalendarEvent
|
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
event_data: EventData = {
|
|
||||||
"ics_uid": "test-123",
|
|
||||||
"title": "Test",
|
|
||||||
"description": "Desc",
|
|
||||||
"location": "Loc",
|
|
||||||
"start_time": now,
|
|
||||||
"end_time": now + timedelta(hours=1),
|
|
||||||
"attendees": [],
|
|
||||||
"ics_raw_data": "raw",
|
|
||||||
}
|
|
||||||
|
|
||||||
existing = CalendarEvent(
|
|
||||||
room_id="room1",
|
|
||||||
**event_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Will raise RuntimeError if fields are missing from comparison
|
|
||||||
result = sync_service._event_data_changed(existing, event_data)
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
modified_data = event_data.copy()
|
|
||||||
modified_data["title"] = "Changed Title"
|
|
||||||
result = sync_service._event_data_changed(existing, modified_data)
|
|
||||||
assert result is True
|
|
||||||
|
|||||||
234
server/tests/test_transcripts_chat.py
Normal file
234
server/tests/test_transcripts_chat.py
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
"""Tests for transcript chat WebSocket endpoint."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx_ws import aconnect_ws
|
||||||
|
from uvicorn import Config, Server
|
||||||
|
|
||||||
|
from reflector.db.transcripts import (
|
||||||
|
SourceKind,
|
||||||
|
TranscriptParticipant,
|
||||||
|
TranscriptTopic,
|
||||||
|
transcripts_controller,
|
||||||
|
)
|
||||||
|
from reflector.processors.types import Word
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def chat_appserver(tmpdir, setup_database):
|
||||||
|
"""Start a real HTTP server for WebSocket testing."""
|
||||||
|
from reflector.app import app
|
||||||
|
from reflector.db import get_database
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
DATA_DIR = settings.DATA_DIR
|
||||||
|
settings.DATA_DIR = Path(tmpdir)
|
||||||
|
|
||||||
|
# Start server in separate thread with its own event loop
|
||||||
|
host = "127.0.0.1"
|
||||||
|
port = 1256 # Different port from rtc tests
|
||||||
|
server_started = threading.Event()
|
||||||
|
server_exception = None
|
||||||
|
server_instance = None
|
||||||
|
|
||||||
|
def run_server():
|
||||||
|
nonlocal server_exception, server_instance
|
||||||
|
try:
|
||||||
|
# Create new event loop for this thread
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
config = Config(app=app, host=host, port=port, loop=loop)
|
||||||
|
server_instance = Server(config)
|
||||||
|
|
||||||
|
async def start_server():
|
||||||
|
# Initialize database connection in this event loop
|
||||||
|
database = get_database()
|
||||||
|
await database.connect()
|
||||||
|
try:
|
||||||
|
await server_instance.serve()
|
||||||
|
finally:
|
||||||
|
await database.disconnect()
|
||||||
|
|
||||||
|
# Signal that server is starting
|
||||||
|
server_started.set()
|
||||||
|
loop.run_until_complete(start_server())
|
||||||
|
except Exception as e:
|
||||||
|
server_exception = e
|
||||||
|
server_started.set()
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
server_thread = threading.Thread(target=run_server, daemon=True)
|
||||||
|
server_thread.start()
|
||||||
|
|
||||||
|
# Wait for server to start
|
||||||
|
server_started.wait(timeout=30)
|
||||||
|
if server_exception:
|
||||||
|
raise server_exception
|
||||||
|
|
||||||
|
# Wait for server to be fully ready
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
yield server_instance, host, port
|
||||||
|
|
||||||
|
# Stop server
|
||||||
|
if server_instance:
|
||||||
|
server_instance.should_exit = True
|
||||||
|
server_thread.join(timeout=30)
|
||||||
|
|
||||||
|
settings.DATA_DIR = DATA_DIR
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def test_transcript(setup_database):
|
||||||
|
"""Create a test transcript for WebSocket tests."""
|
||||||
|
transcript = await transcripts_controller.add(
|
||||||
|
name="Test Transcript for Chat", source_kind=SourceKind.FILE
|
||||||
|
)
|
||||||
|
return transcript
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
async def test_transcript_with_content(setup_database):
    """Build a transcript with two named speakers and one topic.

    Used by tests that need real words/participants so the server can
    render a WebVTT context document.
    """
    transcript = await transcripts_controller.add(
        name="Test Transcript with Content", source_kind=SourceKind.FILE
    )

    # Two named speakers so the WebVTT output carries <v Alice>/<v Bob> cues.
    participants = [
        TranscriptParticipant(id="1", speaker=0, name="Alice").model_dump(),
        TranscriptParticipant(id="2", speaker=1, name="Bob").model_dump(),
    ]
    await transcripts_controller.update(transcript, {"participants": participants})

    # A single topic containing a short two-speaker exchange.
    words = [
        Word(text="Hello ", start=0.0, end=1.0, speaker=0),
        Word(text="everyone.", start=1.0, end=2.0, speaker=0),
        Word(text="Hi ", start=2.0, end=3.0, speaker=1),
        Word(text="there!", start=3.0, end=4.0, speaker=1),
    ]
    topic = TranscriptTopic(
        title="Introduction",
        summary="Opening remarks",
        timestamp=0.0,
        words=words,
    )
    await transcripts_controller.upsert_topic(transcript, topic)

    return transcript
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_chat_websocket_connection_success(test_transcript, chat_appserver):
    """A client can connect to the chat endpoint and exchange a message."""
    _server, host, port = chat_appserver
    url = f"ws://{host}:{port}/v1/transcripts/{test_transcript.id}/chat"

    async with aconnect_ws(url) as ws:
        # Unknown message types are echoed back verbatim by the server,
        # which is enough to prove the round trip works.
        await ws.send_json({"type": "test", "text": "Hello"})

        reply = await ws.receive_json()
        assert reply["type"] == "echo"
        assert reply["data"]["type"] == "test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_chat_websocket_nonexistent_transcript(chat_appserver):
    """Connecting with an unknown transcript id fails."""
    _server, host, port = chat_appserver
    url = f"ws://{host}:{port}/v1/transcripts/nonexistent-id/chat"

    # The server either rejects the handshake or drops the connection
    # immediately, so the connect or the first send/receive raises.
    with pytest.raises(Exception):
        async with aconnect_ws(url) as ws:
            await ws.send_json({"type": "message", "text": "Hello"})
            await ws.receive_json()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_chat_websocket_multiple_messages(test_transcript, chat_appserver):
    """Several messages in a row are each echoed back, in order."""
    _server, host, port = chat_appserver
    url = f"ws://{host}:{port}/v1/transcripts/{test_transcript.id}/chat"

    async with aconnect_ws(url) as ws:
        # Each unknown type "test<i>" should come back as an echo frame
        # carrying the same type and text.
        for index, text in enumerate(
            ["First message", "Second message", "Third message"]
        ):
            await ws.send_json({"type": f"test{index}", "text": text})
            echoed = await ws.receive_json()
            assert echoed["type"] == "echo"
            assert echoed["data"]["type"] == f"test{index}"
            assert echoed["data"]["text"] == text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_chat_websocket_disconnect_graceful(test_transcript, chat_appserver):
    """Closing the WebSocket via the context manager does not raise."""
    _server, host, port = chat_appserver
    url = f"ws://{host}:{port}/v1/transcripts/{test_transcript.id}/chat"

    async with aconnect_ws(url) as ws:
        await ws.send_json({"type": "message", "text": "Hello"})
        await ws.receive_json()
    # Leaving the context manager performs the close handshake; reaching
    # this point without an exception is the assertion.
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_chat_websocket_context_generation(
    test_transcript_with_content, chat_appserver
):
    """The server returns a WebVTT context document on request."""
    _server, host, port = chat_appserver
    url = (
        f"ws://{host}:{port}/v1"
        f"/transcripts/{test_transcript_with_content.id}/chat"
    )

    async with aconnect_ws(url) as ws:
        await ws.send_json({"type": "get_context"})

        reply = await ws.receive_json()
        assert reply["type"] == "context"
        assert "webvtt" in reply

        # The payload must be well-formed WebVTT containing the speakers
        # and words seeded by the fixture.
        vtt = reply["webvtt"]
        assert vtt.startswith("WEBVTT")
        for expected in ("<v Alice>", "<v Bob>", "Hello everyone.", "Hi there!"):
            assert expected in vtt
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_chat_websocket_unknown_message_type(test_transcript, chat_appserver):
    """Messages with an unrecognized type are echoed back to the sender."""
    _server, host, port = chat_appserver
    url = f"ws://{host}:{port}/v1/transcripts/{test_transcript.id}/chat"

    async with aconnect_ws(url) as ws:
        await ws.send_json({"type": "unknown", "data": "test"})

        reply = await ws.receive_json()
        assert reply["type"] == "echo"
        assert reply["data"]["type"] == "unknown"
|
||||||
103
www/app/(app)/transcripts/TranscriptChatModal.tsx
Normal file
103
www/app/(app)/transcripts/TranscriptChatModal.tsx
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useState } from "react";
|
||||||
|
import { Box, Dialog, Input, IconButton } from "@chakra-ui/react";
|
||||||
|
import { MessageCircle } from "lucide-react";
|
||||||
|
import Markdown from "react-markdown";
|
||||||
|
import "../../styles/markdown.css";
|
||||||
|
import type { Message } from "./useTranscriptChat";
|
||||||
|
|
||||||
|
// Props for the transcript chat dialog.
interface TranscriptChatModalProps {
  // Whether the dialog is visible.
  open: boolean;
  // Invoked when the dialog requests to close.
  onClose: () => void;
  // Completed conversation turns (user and assistant).
  messages: Message[];
  // Sends a user message over the chat transport.
  sendMessage: (text: string) => void;
  // True while an assistant reply is still streaming in.
  isStreaming: boolean;
  // Partial assistant text accumulated so far during streaming.
  currentStreamingText: string;
}
|
||||||
|
|
||||||
|
export function TranscriptChatModal({
|
||||||
|
open,
|
||||||
|
onClose,
|
||||||
|
messages,
|
||||||
|
sendMessage,
|
||||||
|
isStreaming,
|
||||||
|
currentStreamingText,
|
||||||
|
}: TranscriptChatModalProps) {
|
||||||
|
const [input, setInput] = useState("");
|
||||||
|
|
||||||
|
const handleSend = () => {
|
||||||
|
if (!input.trim()) return;
|
||||||
|
sendMessage(input);
|
||||||
|
setInput("");
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Dialog.Root open={open} onOpenChange={(e) => !e.open && onClose()}>
|
||||||
|
<Dialog.Backdrop />
|
||||||
|
<Dialog.Positioner>
|
||||||
|
<Dialog.Content maxW="500px" h="600px">
|
||||||
|
<Dialog.Header>Transcript Chat</Dialog.Header>
|
||||||
|
|
||||||
|
<Dialog.Body overflowY="auto">
|
||||||
|
{messages.map((msg) => (
|
||||||
|
<Box
|
||||||
|
key={msg.id}
|
||||||
|
p={3}
|
||||||
|
mb={2}
|
||||||
|
bg={msg.role === "user" ? "blue.50" : "gray.50"}
|
||||||
|
borderRadius="md"
|
||||||
|
>
|
||||||
|
{msg.role === "user" ? (
|
||||||
|
msg.text
|
||||||
|
) : (
|
||||||
|
<div className="markdown">
|
||||||
|
<Markdown>{msg.text}</Markdown>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
))}
|
||||||
|
|
||||||
|
{isStreaming && (
|
||||||
|
<Box p={3} bg="gray.50" borderRadius="md">
|
||||||
|
<div className="markdown">
|
||||||
|
<Markdown>{currentStreamingText}</Markdown>
|
||||||
|
</div>
|
||||||
|
<Box as="span" className="animate-pulse">
|
||||||
|
▊
|
||||||
|
</Box>
|
||||||
|
</Box>
|
||||||
|
)}
|
||||||
|
</Dialog.Body>
|
||||||
|
|
||||||
|
<Dialog.Footer>
|
||||||
|
<Input
|
||||||
|
value={input}
|
||||||
|
onChange={(e) => setInput(e.target.value)}
|
||||||
|
onKeyDown={(e) => e.key === "Enter" && handleSend()}
|
||||||
|
placeholder="Ask about transcript..."
|
||||||
|
disabled={isStreaming}
|
||||||
|
/>
|
||||||
|
</Dialog.Footer>
|
||||||
|
</Dialog.Content>
|
||||||
|
</Dialog.Positioner>
|
||||||
|
</Dialog.Root>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function TranscriptChatButton({ onClick }: { onClick: () => void }) {
|
||||||
|
return (
|
||||||
|
<IconButton
|
||||||
|
position="fixed"
|
||||||
|
bottom="24px"
|
||||||
|
right="24px"
|
||||||
|
onClick={onClick}
|
||||||
|
size="lg"
|
||||||
|
colorPalette="blue"
|
||||||
|
borderRadius="full"
|
||||||
|
aria-label="Open chat"
|
||||||
|
>
|
||||||
|
<MessageCircle />
|
||||||
|
</IconButton>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -18,9 +18,15 @@ import {
|
|||||||
Skeleton,
|
Skeleton,
|
||||||
Text,
|
Text,
|
||||||
Spinner,
|
Spinner,
|
||||||
|
useDisclosure,
|
||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import { useTranscriptGet } from "../../../lib/apiHooks";
|
import { useTranscriptGet } from "../../../lib/apiHooks";
|
||||||
import { TranscriptStatus } from "../../../lib/transcript";
|
import { TranscriptStatus } from "../../../lib/transcript";
|
||||||
|
import {
|
||||||
|
TranscriptChatModal,
|
||||||
|
TranscriptChatButton,
|
||||||
|
} from "../TranscriptChatModal";
|
||||||
|
import { useTranscriptChat } from "../useTranscriptChat";
|
||||||
|
|
||||||
type TranscriptDetails = {
|
type TranscriptDetails = {
|
||||||
params: Promise<{
|
params: Promise<{
|
||||||
@@ -53,6 +59,9 @@ export default function TranscriptDetails(details: TranscriptDetails) {
|
|||||||
const [finalSummaryElement, setFinalSummaryElement] =
|
const [finalSummaryElement, setFinalSummaryElement] =
|
||||||
useState<HTMLDivElement | null>(null);
|
useState<HTMLDivElement | null>(null);
|
||||||
|
|
||||||
|
const { open, onOpen, onClose } = useDisclosure();
|
||||||
|
const chat = useTranscriptChat(transcriptId);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!waiting || !transcript.data) return;
|
if (!waiting || !transcript.data) return;
|
||||||
|
|
||||||
@@ -119,6 +128,15 @@ export default function TranscriptDetails(details: TranscriptDetails) {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
|
<TranscriptChatModal
|
||||||
|
open={open}
|
||||||
|
onClose={onClose}
|
||||||
|
messages={chat.messages}
|
||||||
|
sendMessage={chat.sendMessage}
|
||||||
|
isStreaming={chat.isStreaming}
|
||||||
|
currentStreamingText={chat.currentStreamingText}
|
||||||
|
/>
|
||||||
|
<TranscriptChatButton onClick={onOpen} />
|
||||||
<Grid
|
<Grid
|
||||||
templateColumns="1fr"
|
templateColumns="1fr"
|
||||||
templateRows="auto minmax(0, 1fr)"
|
templateRows="auto minmax(0, 1fr)"
|
||||||
|
|||||||
130
www/app/(app)/transcripts/useTranscriptChat.ts
Normal file
130
www/app/(app)/transcripts/useTranscriptChat.ts
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useEffect, useState, useRef } from "react";
|
||||||
|
import { getSession } from "next-auth/react";
|
||||||
|
import { WEBSOCKET_URL } from "../../lib/apiClient";
|
||||||
|
import { assertExtendedToken } from "../../lib/types";
|
||||||
|
|
||||||
|
// One chat turn shown in the transcript chat UI.
export type Message = {
  // Unique key for list rendering (millisecond timestamp string).
  id: string;
  // Author of the turn.
  role: "user" | "assistant";
  // Message body; assistant text may contain markdown.
  text: string;
  // Client-side creation time of the message.
  timestamp: Date;
};

// Public surface returned by useTranscriptChat.
export type UseTranscriptChat = {
  messages: Message[];
  sendMessage: (text: string) => void;
  // True while assistant tokens are still arriving over the WebSocket.
  isStreaming: boolean;
  // Partial assistant reply accumulated so far during streaming.
  currentStreamingText: string;
};
|
||||||
|
|
||||||
|
export const useTranscriptChat = (transcriptId: string): UseTranscriptChat => {
|
||||||
|
const [messages, setMessages] = useState<Message[]>([]);
|
||||||
|
const [isStreaming, setIsStreaming] = useState(false);
|
||||||
|
const [currentStreamingText, setCurrentStreamingText] = useState("");
|
||||||
|
const wsRef = useRef<WebSocket | null>(null);
|
||||||
|
const streamingTextRef = useRef<string>("");
|
||||||
|
const isMountedRef = useRef<boolean>(true);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
isMountedRef.current = true;
|
||||||
|
|
||||||
|
const connectWebSocket = async () => {
|
||||||
|
const url = `${WEBSOCKET_URL}/v1/transcripts/${transcriptId}/chat`;
|
||||||
|
|
||||||
|
// Get auth token for WebSocket subprotocol
|
||||||
|
let protocols: string[] | undefined;
|
||||||
|
try {
|
||||||
|
const session = await getSession();
|
||||||
|
if (session) {
|
||||||
|
const token = assertExtendedToken(session).accessToken;
|
||||||
|
// Pass token via subprotocol: ["bearer", token]
|
||||||
|
protocols = ["bearer", token];
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.warn("Failed to get auth token for WebSocket:", error);
|
||||||
|
}
|
||||||
|
|
||||||
|
const ws = new WebSocket(url, protocols);
|
||||||
|
wsRef.current = ws;
|
||||||
|
|
||||||
|
ws.onopen = () => {
|
||||||
|
console.log("Chat WebSocket connected");
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onmessage = (event) => {
|
||||||
|
if (!isMountedRef.current) return;
|
||||||
|
|
||||||
|
const msg = JSON.parse(event.data);
|
||||||
|
|
||||||
|
switch (msg.type) {
|
||||||
|
case "token":
|
||||||
|
setIsStreaming(true);
|
||||||
|
streamingTextRef.current += msg.text;
|
||||||
|
setCurrentStreamingText(streamingTextRef.current);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "done":
|
||||||
|
// CRITICAL: Save the text BEFORE resetting the ref
|
||||||
|
// The setMessages callback may execute later, after ref is reset
|
||||||
|
const finalText = streamingTextRef.current;
|
||||||
|
|
||||||
|
setMessages((prev) => [
|
||||||
|
...prev,
|
||||||
|
{
|
||||||
|
id: Date.now().toString(),
|
||||||
|
role: "assistant",
|
||||||
|
text: finalText,
|
||||||
|
timestamp: new Date(),
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
streamingTextRef.current = "";
|
||||||
|
setCurrentStreamingText("");
|
||||||
|
setIsStreaming(false);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "error":
|
||||||
|
console.error("Chat error:", msg.message);
|
||||||
|
setIsStreaming(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onerror = (error) => {
|
||||||
|
console.error("WebSocket error:", error);
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onclose = () => {
|
||||||
|
console.log("Chat WebSocket closed");
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
connectWebSocket();
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
isMountedRef.current = false;
|
||||||
|
if (wsRef.current) {
|
||||||
|
wsRef.current.close();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}, [transcriptId]);
|
||||||
|
|
||||||
|
const sendMessage = (text: string) => {
|
||||||
|
if (!wsRef.current) return;
|
||||||
|
|
||||||
|
setMessages((prev) => [
|
||||||
|
...prev,
|
||||||
|
{
|
||||||
|
id: Date.now().toString(),
|
||||||
|
role: "user",
|
||||||
|
text,
|
||||||
|
timestamp: new Date(),
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
wsRef.current.send(JSON.stringify({ type: "message", text }));
|
||||||
|
};
|
||||||
|
|
||||||
|
return { messages, sendMessage, isStreaming, currentStreamingText };
|
||||||
|
};
|
||||||
Reference in New Issue
Block a user