From 3b6540eae5b597449f98661bdf15483b77be3268 Mon Sep 17 00:00:00 2001 From: Igor Monadical Date: Tue, 20 Jan 2026 12:27:16 -0500 Subject: [PATCH] feat: worker affinity (#819) * worker affinity * worker affinity * worker affinity --------- Co-authored-by: Igor Loskutov --- docker-compose.yml | 17 +++- server/reflector/hatchet/run_workers.py | 77 ------------------- server/reflector/hatchet/run_workers_cpu.py | 48 ++++++++++++ server/reflector/hatchet/run_workers_llm.py | 56 ++++++++++++++ .../workflows/daily_multitrack_pipeline.py | 21 ++++- .../workflows/topic_chunk_processing.py | 18 +++-- server/runserver.sh | 6 +- 7 files changed, 155 insertions(+), 88 deletions(-) delete mode 100644 server/reflector/hatchet/run_workers.py create mode 100644 server/reflector/hatchet/run_workers_cpu.py create mode 100644 server/reflector/hatchet/run_workers_llm.py diff --git a/docker-compose.yml b/docker-compose.yml index 380abacf..c97deb08 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -34,7 +34,7 @@ services: environment: ENTRYPOINT: beat - hatchet-worker: + hatchet-worker-cpu: build: context: server volumes: @@ -43,7 +43,20 @@ services: env_file: - ./server/.env environment: - ENTRYPOINT: hatchet-worker + ENTRYPOINT: hatchet-worker-cpu + depends_on: + hatchet: + condition: service_healthy + hatchet-worker-llm: + build: + context: server + volumes: + - ./server/:/app/ + - /app/.venv + env_file: + - ./server/.env + environment: + ENTRYPOINT: hatchet-worker-llm depends_on: hatchet: condition: service_healthy diff --git a/server/reflector/hatchet/run_workers.py b/server/reflector/hatchet/run_workers.py deleted file mode 100644 index e6f21653..00000000 --- a/server/reflector/hatchet/run_workers.py +++ /dev/null @@ -1,77 +0,0 @@ -""" -Run Hatchet workers for the multitrack pipeline. -Runs as a separate process, just like Celery workers. - -Usage: - uv run -m reflector.hatchet.run_workers - - # Or via docker: - docker compose exec server uv run -m reflector.hatchet.run_workers -""" - -import signal -import sys - -from hatchet_sdk.rate_limit import RateLimitDuration - -from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, LLM_RATE_LIMIT_PER_SECOND -from reflector.logger import logger -from reflector.settings import settings - - -def main() -> None: - """Start Hatchet worker polling.""" - if not settings.HATCHET_ENABLED: - logger.error("HATCHET_ENABLED is False, not starting workers") - sys.exit(1) - - if not settings.HATCHET_CLIENT_TOKEN: - logger.error("HATCHET_CLIENT_TOKEN is not set") - sys.exit(1) - - logger.info( - "Starting Hatchet workers", - debug=settings.HATCHET_DEBUG, - ) - - # Import here (not top-level) - workflow modules call HatchetClientManager.get_client() - # at module level because Hatchet SDK decorators (@workflow.task) bind at import time. - # Can't use lazy init: decorators need the client object when function is defined. 
- from reflector.hatchet.client import HatchetClientManager # noqa: PLC0415 - from reflector.hatchet.workflows import ( # noqa: PLC0415 - daily_multitrack_pipeline, - subject_workflow, - topic_chunk_workflow, - track_workflow, - ) - - hatchet = HatchetClientManager.get_client() - - hatchet.rate_limits.put( - LLM_RATE_LIMIT_KEY, LLM_RATE_LIMIT_PER_SECOND, RateLimitDuration.SECOND - ) - - worker = hatchet.worker( - "reflector-pipeline-worker", - workflows=[ - daily_multitrack_pipeline, - subject_workflow, - topic_chunk_workflow, - track_workflow, - ], - ) - - def shutdown_handler(signum: int, frame) -> None: - logger.info("Received shutdown signal, stopping workers...") - # Worker cleanup happens automatically on exit - sys.exit(0) - - signal.signal(signal.SIGINT, shutdown_handler) - signal.signal(signal.SIGTERM, shutdown_handler) - - logger.info("Starting Hatchet worker polling...") - worker.start() - - -if __name__ == "__main__": - main() diff --git a/server/reflector/hatchet/run_workers_cpu.py b/server/reflector/hatchet/run_workers_cpu.py new file mode 100644 index 00000000..3fa1106d --- /dev/null +++ b/server/reflector/hatchet/run_workers_cpu.py @@ -0,0 +1,48 @@ +""" +CPU-heavy worker pool for audio processing tasks. +Handles ONLY: mixdown_tracks + +Configuration: +- slots=1: Only mixdown (already serialized globally with max_runs=1) +- Worker affinity: pool=cpu-heavy +""" + +from reflector.hatchet.client import HatchetClientManager +from reflector.hatchet.workflows.daily_multitrack_pipeline import ( + daily_multitrack_pipeline, +) +from reflector.logger import logger +from reflector.settings import settings + + +def main(): + if not settings.HATCHET_ENABLED: + logger.error("HATCHET_ENABLED is False, not starting CPU workers") + return + + hatchet = HatchetClientManager.get_client() + + logger.info( + "Starting Hatchet CPU worker pool (mixdown only)", + worker_name="cpu-worker-pool", + slots=1, + labels={"pool": "cpu-heavy"}, + ) + + cpu_worker = hatchet.worker( + "cpu-worker-pool", + slots=1, # Only 1 mixdown at a time (already serialized globally) + labels={ + "pool": "cpu-heavy", + }, + workflows=[daily_multitrack_pipeline], + ) + + try: + cpu_worker.start() + except KeyboardInterrupt: + logger.info("Received shutdown signal, stopping CPU workers...") + + +if __name__ == "__main__": + main() diff --git a/server/reflector/hatchet/run_workers_llm.py b/server/reflector/hatchet/run_workers_llm.py new file mode 100644 index 00000000..00c3a115 --- /dev/null +++ b/server/reflector/hatchet/run_workers_llm.py @@ -0,0 +1,56 @@ +""" +LLM/I/O worker pool for all non-CPU tasks. 
+Handles: all tasks except mixdown_tracks (transcription, LLM inference, orchestration)
+"""
+
+from reflector.hatchet.client import HatchetClientManager
+from reflector.hatchet.workflows.daily_multitrack_pipeline import (
+    daily_multitrack_pipeline,
+)
+from reflector.hatchet.workflows.subject_processing import subject_workflow
+from reflector.hatchet.workflows.topic_chunk_processing import topic_chunk_workflow
+from reflector.hatchet.workflows.track_processing import track_workflow
+from reflector.logger import logger
+from reflector.settings import settings
+
+SLOTS = 10
+WORKER_NAME = "llm-worker-pool"
+POOL = "llm-io"
+
+
+def main():
+    if not settings.HATCHET_ENABLED:
+        logger.error("HATCHET_ENABLED is False, not starting LLM workers")
+        return
+
+    hatchet = HatchetClientManager.get_client()
+
+    logger.info(
+        "Starting Hatchet LLM worker pool (all tasks except mixdown)",
+        worker_name=WORKER_NAME,
+        slots=SLOTS,
+        labels={"pool": POOL},
+    )
+
+    llm_worker = hatchet.worker(
+        WORKER_NAME,
+        slots=SLOTS,  # not all slots will necessarily be in use
+        labels={
+            "pool": POOL,
+        },
+        workflows=[
+            daily_multitrack_pipeline,
+            topic_chunk_workflow,
+            subject_workflow,
+            track_workflow,
+        ],
+    )
+
+    try:
+        llm_worker.start()
+    except KeyboardInterrupt:
+        logger.info("Received shutdown signal, stopping LLM workers...")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/server/reflector/hatchet/workflows/daily_multitrack_pipeline.py b/server/reflector/hatchet/workflows/daily_multitrack_pipeline.py
index e9cb39d1..0726cfd6 100644
--- a/server/reflector/hatchet/workflows/daily_multitrack_pipeline.py
+++ b/server/reflector/hatchet/workflows/daily_multitrack_pipeline.py
@@ -23,7 +23,12 @@ from pathlib import Path
 from typing import Any, Callable, Coroutine, Protocol, TypeVar
 
 import httpx
-from hatchet_sdk import Context
+from hatchet_sdk import (
+    ConcurrencyExpression,
+    ConcurrencyLimitStrategy,
+    Context,
+)
+from hatchet_sdk.labels import DesiredWorkerLabel
 from pydantic import BaseModel
 
 from reflector.dailyco_api.client import DailyApiClient
@@ -467,6 +472,20 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
     parents=[process_tracks],
     execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
     retries=3,
+    desired_worker_labels={
+        "pool": DesiredWorkerLabel(
+            value="cpu-heavy",
+            required=True,
+            weight=100,
+        ),
+    },
+    concurrency=[
+        ConcurrencyExpression(
+            expression="'mixdown-global'",
+            max_runs=1,  # serialize mixdown to prevent resource contention
+            limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,  # queue excess runs instead of cancelling
+        )
+    ],
 )
 @with_error_handling(TaskName.MIXDOWN_TRACKS)
 async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
diff --git a/server/reflector/hatchet/workflows/topic_chunk_processing.py b/server/reflector/hatchet/workflows/topic_chunk_processing.py
index 6a062b1a..b545b082 100644
--- a/server/reflector/hatchet/workflows/topic_chunk_processing.py
+++ b/server/reflector/hatchet/workflows/topic_chunk_processing.py
@@ -7,7 +7,11 @@ Spawned dynamically by detect_topics via aio_run_many() for parallel processing.
from datetime import timedelta -from hatchet_sdk import ConcurrencyExpression, ConcurrencyLimitStrategy, Context +from hatchet_sdk import ( + ConcurrencyExpression, + ConcurrencyLimitStrategy, + Context, +) from hatchet_sdk.rate_limit import RateLimit from pydantic import BaseModel @@ -34,11 +38,13 @@ hatchet = HatchetClientManager.get_client() topic_chunk_workflow = hatchet.workflow( name="TopicChunkProcessing", input_validator=TopicChunkInput, - concurrency=ConcurrencyExpression( - expression="'global'", # constant string = global limit across all runs - max_runs=20, - limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, - ), + concurrency=[ + ConcurrencyExpression( + expression="'global'", # constant string = global limit across all runs + max_runs=20, + limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, + ) + ], ) diff --git a/server/runserver.sh b/server/runserver.sh index 3b8976db..68300885 100755 --- a/server/runserver.sh +++ b/server/runserver.sh @@ -7,8 +7,10 @@ elif [ "${ENTRYPOINT}" = "worker" ]; then uv run celery -A reflector.worker.app worker --loglevel=info elif [ "${ENTRYPOINT}" = "beat" ]; then uv run celery -A reflector.worker.app beat --loglevel=info -elif [ "${ENTRYPOINT}" = "hatchet-worker" ]; then - uv run python -m reflector.hatchet.run_workers +elif [ "${ENTRYPOINT}" = "hatchet-worker-cpu" ]; then + uv run python -m reflector.hatchet.run_workers_cpu +elif [ "${ENTRYPOINT}" = "hatchet-worker-llm" ]; then + uv run python -m reflector.hatchet.run_workers_llm else echo "Unknown command" fi
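
The worker-affinity pattern this patch applies to mixdown_tracks condenses to the sketch below. This is an illustration, not part of the patch: the Hatchet SDK calls mirror the ones used above, while ExamplePipeline, heavy_task, and example-cpu-pool are hypothetical names.

    """Illustrative sketch of the worker-affinity pattern in this patch.

    Names here are hypothetical; the SDK calls mirror the real workflow and
    worker modules above.
    """

    from datetime import timedelta

    from hatchet_sdk import ConcurrencyExpression, ConcurrencyLimitStrategy, Context
    from hatchet_sdk.labels import DesiredWorkerLabel

    from reflector.hatchet.client import HatchetClientManager

    hatchet = HatchetClientManager.get_client()
    workflow = hatchet.workflow(name="ExamplePipeline")


    @workflow.task(
        execution_timeout=timedelta(minutes=30),
        # Hard affinity: with required=True the task waits for a worker whose
        # labels contain pool=cpu-heavy rather than running on any free worker.
        desired_worker_labels={
            "pool": DesiredWorkerLabel(value="cpu-heavy", required=True, weight=100),
        },
        # A constant expression puts every run in one global concurrency group;
        # max_runs=1 therefore serializes the task across the whole cluster.
        concurrency=[
            ConcurrencyExpression(
                expression="'heavy-global'",
                max_runs=1,
                limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
            )
        ],
    )
    async def heavy_task(input, ctx: Context) -> dict:
        return {"ok": True}


    def main() -> None:
        # Only a worker advertising pool=cpu-heavy can be assigned heavy_task.
        worker = hatchet.worker(
            "example-cpu-pool",
            slots=1,
            labels={"pool": "cpu-heavy"},
            workflows=[workflow],
        )
        worker.start()


    if __name__ == "__main__":
        main()

Tasks without desired_worker_labels (everything else in daily_multitrack_pipeline) remain schedulable on any pool, which is why both worker pools register the parent workflow.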