Mirror of https://github.com/Monadical-SAS/reflector.git
Synced 2026-02-04 09:56:47 +00:00
feat: worker affinity (#819)
* worker affinity

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
@@ -34,7 +34,7 @@ services:
     environment:
       ENTRYPOINT: beat

-  hatchet-worker:
+  hatchet-worker-cpu:
     build:
       context: server
     volumes:
@@ -43,7 +43,20 @@ services:
     env_file:
       - ./server/.env
     environment:
-      ENTRYPOINT: hatchet-worker
+      ENTRYPOINT: hatchet-worker-cpu
+    depends_on:
+      hatchet:
+        condition: service_healthy
+  hatchet-worker-llm:
+    build:
+      context: server
+    volumes:
+      - ./server/:/app/
+      - /app/.venv
+    env_file:
+      - ./server/.env
+    environment:
+      ENTRYPOINT: hatchet-worker-llm
     depends_on:
       hatchet:
         condition: service_healthy
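The single hatchet-worker service becomes two: hatchet-worker-cpu and hatchet-worker-llm. Both build the same server image and mount the same sources; only the ENTRYPOINT value differs, and the entrypoint script (changed at the end of this diff) dispatches each value to its own worker module.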
server/reflector/hatchet/run_workers.py (deleted)
@@ -1,77 +0,0 @@
-"""
-Run Hatchet workers for the multitrack pipeline.
-Runs as a separate process, just like Celery workers.
-
-Usage:
-    uv run -m reflector.hatchet.run_workers
-
-    # Or via docker:
-    docker compose exec server uv run -m reflector.hatchet.run_workers
-"""
-
-import signal
-import sys
-
-from hatchet_sdk.rate_limit import RateLimitDuration
-
-from reflector.hatchet.constants import LLM_RATE_LIMIT_KEY, LLM_RATE_LIMIT_PER_SECOND
-from reflector.logger import logger
-from reflector.settings import settings
-
-
-def main() -> None:
-    """Start Hatchet worker polling."""
-    if not settings.HATCHET_ENABLED:
-        logger.error("HATCHET_ENABLED is False, not starting workers")
-        sys.exit(1)
-
-    if not settings.HATCHET_CLIENT_TOKEN:
-        logger.error("HATCHET_CLIENT_TOKEN is not set")
-        sys.exit(1)
-
-    logger.info(
-        "Starting Hatchet workers",
-        debug=settings.HATCHET_DEBUG,
-    )
-
-    # Import here (not top-level) - workflow modules call HatchetClientManager.get_client()
-    # at module level because Hatchet SDK decorators (@workflow.task) bind at import time.
-    # Can't use lazy init: decorators need the client object when function is defined.
-    from reflector.hatchet.client import HatchetClientManager  # noqa: PLC0415
-    from reflector.hatchet.workflows import (  # noqa: PLC0415
-        daily_multitrack_pipeline,
-        subject_workflow,
-        topic_chunk_workflow,
-        track_workflow,
-    )
-
-    hatchet = HatchetClientManager.get_client()
-
-    hatchet.rate_limits.put(
-        LLM_RATE_LIMIT_KEY, LLM_RATE_LIMIT_PER_SECOND, RateLimitDuration.SECOND
-    )
-
-    worker = hatchet.worker(
-        "reflector-pipeline-worker",
-        workflows=[
-            daily_multitrack_pipeline,
-            subject_workflow,
-            topic_chunk_workflow,
-            track_workflow,
-        ],
-    )
-
-    def shutdown_handler(signum: int, frame) -> None:
-        logger.info("Received shutdown signal, stopping workers...")
-        # Worker cleanup happens automatically on exit
-        sys.exit(0)
-
-    signal.signal(signal.SIGINT, shutdown_handler)
-    signal.signal(signal.SIGTERM, shutdown_handler)

-    logger.info("Starting Hatchet worker polling...")
-    worker.start()
-
-
-if __name__ == "__main__":
-    main()
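In place of this single catch-all pool, the commit splits scheduling by worker affinity: a worker advertises static labels when it registers, and a task declares the labels it requires. A minimal sketch of the mechanism, assuming the Hatchet v1 Python SDK; the bare Hatchet() client, the ExampleWorkflow/heavy_task names, and the pool name are illustrative, not part of this commit:

from hatchet_sdk import Hatchet
from hatchet_sdk.labels import DesiredWorkerLabel

hatchet = Hatchet()  # reads HATCHET_CLIENT_TOKEN from the environment

wf = hatchet.workflow(name="ExampleWorkflow")

# required=True: only workers whose "pool" label matches may run this task;
# weight expresses scheduling preference among matching workers.
@wf.task(
    desired_worker_labels={
        "pool": DesiredWorkerLabel(value="cpu-heavy", required=True, weight=100),
    },
)
def heavy_task(input, ctx):
    ...

# Workers advertise static labels at registration time; only this worker
# is eligible to run heavy_task.
worker = hatchet.worker(
    "example-cpu-pool",
    slots=1,
    labels={"pool": "cpu-heavy"},
    workflows=[wf],
)
worker.start()  # blocks, polling for matching runs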
server/reflector/hatchet/run_workers_cpu.py (new file, 48 lines)
@@ -0,0 +1,48 @@
+"""
+CPU-heavy worker pool for audio processing tasks.
+Handles ONLY: mixdown_tracks
+
+Configuration:
+- slots=1: Only mixdown (already serialized globally with max_runs=1)
+- Worker affinity: pool=cpu-heavy
+"""
+
+from reflector.hatchet.client import HatchetClientManager
+from reflector.hatchet.workflows.daily_multitrack_pipeline import (
+    daily_multitrack_pipeline,
+)
+from reflector.logger import logger
+from reflector.settings import settings
+
+
+def main():
+    if not settings.HATCHET_ENABLED:
+        logger.error("HATCHET_ENABLED is False, not starting CPU workers")
+        return
+
+    hatchet = HatchetClientManager.get_client()
+
+    logger.info(
+        "Starting Hatchet CPU worker pool (mixdown only)",
+        worker_name="cpu-worker-pool",
+        slots=1,
+        labels={"pool": "cpu-heavy"},
+    )
+
+    cpu_worker = hatchet.worker(
+        "cpu-worker-pool",
+        slots=1,  # Only 1 mixdown at a time (already serialized globally)
+        labels={
+            "pool": "cpu-heavy",
+        },
+        workflows=[daily_multitrack_pipeline],
+    )
+
+    try:
+        cpu_worker.start()
+    except KeyboardInterrupt:
+        logger.info("Received shutdown signal, stopping CPU workers...")
+
+
+if __name__ == "__main__":
+    main()
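The slots=1 setting mirrors the global max_runs=1 concurrency placed on mixdown_tracks later in this diff: this pool accepts one task at a time. Like the old runner, it can also be started by hand, e.g. uv run python -m reflector.hatchet.run_workers_cpu, matching the entrypoint case at the end of the diff.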
server/reflector/hatchet/run_workers_llm.py (new file, 56 lines)
@@ -0,0 +1,56 @@
+"""
+LLM/I/O worker pool for all non-CPU tasks.
+Handles: all tasks except mixdown_tracks (transcription, LLM inference, orchestration)
+"""
+
+from reflector.hatchet.client import HatchetClientManager
+from reflector.hatchet.workflows.daily_multitrack_pipeline import (
+    daily_multitrack_pipeline,
+)
+from reflector.hatchet.workflows.subject_processing import subject_workflow
+from reflector.hatchet.workflows.topic_chunk_processing import topic_chunk_workflow
+from reflector.hatchet.workflows.track_processing import track_workflow
+from reflector.logger import logger
+from reflector.settings import settings
+
+SLOTS = 10
+WORKER_NAME = "llm-worker-pool"
+POOL = "llm-io"
+
+
+def main():
+    if not settings.HATCHET_ENABLED:
+        logger.error("HATCHET_ENABLED is False, not starting LLM workers")
+        return
+
+    hatchet = HatchetClientManager.get_client()
+
+    logger.info(
+        "Starting Hatchet LLM worker pool (all tasks except mixdown)",
+        worker_name=WORKER_NAME,
+        slots=SLOTS,
+        labels={"pool": POOL},
+    )
+
+    llm_worker = hatchet.worker(
+        WORKER_NAME,
+        slots=SLOTS,  # not all slots are probably used
+        labels={
+            "pool": POOL,
+        },
+        workflows=[
+            daily_multitrack_pipeline,
+            topic_chunk_workflow,
+            subject_workflow,
+            track_workflow,
+        ],
+    )
+
+    try:
+        llm_worker.start()
+    except KeyboardInterrupt:
+        logger.info("Received shutdown signal, stopping LLM workers...")
+
+
+if __name__ == "__main__":
+    main()
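Both pools register daily_multitrack_pipeline. Registration alone does not decide placement: the affinity label added to mixdown_tracks below (required pool=cpu-heavy) keeps mixdown off this pool, while the pipeline's other tasks run wherever a slot is free.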
@@ -23,7 +23,12 @@ from pathlib import Path
 from typing import Any, Callable, Coroutine, Protocol, TypeVar

 import httpx
-from hatchet_sdk import Context
+from hatchet_sdk import (
+    ConcurrencyExpression,
+    ConcurrencyLimitStrategy,
+    Context,
+)
+from hatchet_sdk.labels import DesiredWorkerLabel
 from pydantic import BaseModel

 from reflector.dailyco_api.client import DailyApiClient
@@ -467,6 +472,20 @@ async def process_tracks(input: PipelineInput, ctx: Context) -> ProcessTracksRes
     parents=[process_tracks],
     execution_timeout=timedelta(seconds=TIMEOUT_AUDIO),
     retries=3,
+    desired_worker_labels={
+        "pool": DesiredWorkerLabel(
+            value="cpu-heavy",
+            required=True,
+            weight=100,
+        ),
+    },
+    concurrency=[
+        ConcurrencyExpression(
+            expression="'mixdown-global'",
+            max_runs=1,  # serialize mixdown to prevent resource contention
+            limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,  # Queue
+        )
+    ],
 )
 @with_error_handling(TaskName.MIXDOWN_TRACKS)
 async def mixdown_tracks(input: PipelineInput, ctx: Context) -> MixdownResult:
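The expression "'mixdown-global'" is a constant string, so every run falls into the same concurrency group: with max_runs=1, at most one mixdown executes across the whole cluster, and GROUP_ROUND_ROBIN queues the overflow rather than cancelling it. Combined with the required pool=cpu-heavy label above, mixdown is pinned to the single-slot CPU pool.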
@@ -7,7 +7,11 @@ Spawned dynamically by detect_topics via aio_run_many() for parallel processing.

 from datetime import timedelta

-from hatchet_sdk import ConcurrencyExpression, ConcurrencyLimitStrategy, Context
+from hatchet_sdk import (
+    ConcurrencyExpression,
+    ConcurrencyLimitStrategy,
+    Context,
+)
 from hatchet_sdk.rate_limit import RateLimit
 from pydantic import BaseModel
@@ -34,11 +38,13 @@ hatchet = HatchetClientManager.get_client()
 topic_chunk_workflow = hatchet.workflow(
     name="TopicChunkProcessing",
     input_validator=TopicChunkInput,
-    concurrency=ConcurrencyExpression(
-        expression="'global'",  # constant string = global limit across all runs
-        max_runs=20,
-        limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
-    ),
+    concurrency=[
+        ConcurrencyExpression(
+            expression="'global'",  # constant string = global limit across all runs
+            max_runs=20,
+            limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
+        )
+    ],
 )
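This is the same shape change applied to mixdown_tracks above: concurrency now takes a list of ConcurrencyExpression objects instead of a single expression, presumably so several limits can be combined; the limit itself ('global', max_runs=20) is unchanged.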
@@ -7,8 +7,10 @@ elif [ "${ENTRYPOINT}" = "worker" ]; then
     uv run celery -A reflector.worker.app worker --loglevel=info
 elif [ "${ENTRYPOINT}" = "beat" ]; then
     uv run celery -A reflector.worker.app beat --loglevel=info
-elif [ "${ENTRYPOINT}" = "hatchet-worker" ]; then
-    uv run python -m reflector.hatchet.run_workers
+elif [ "${ENTRYPOINT}" = "hatchet-worker-cpu" ]; then
+    uv run python -m reflector.hatchet.run_workers_cpu
+elif [ "${ENTRYPOINT}" = "hatchet-worker-llm" ]; then
+    uv run python -m reflector.hatchet.run_workers_llm
 else
     echo "Unknown command"
 fi
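With the compose services above, both pools come up via e.g. docker compose up hatchet-worker-cpu hatchet-worker-llm (the hatchet service is started first and health-checked through depends_on); any other ENTRYPOINT value still falls through to "Unknown command".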