diff --git a/server/.gitignore b/server/.gitignore
index 7adb7fc0..2a82a747 100644
--- a/server/.gitignore
+++ b/server/.gitignore
@@ -113,7 +113,7 @@ ipython_config.py
 __pypackages__/
 
 # Celery stuff
-celerybeat-schedule
+celerybeat-schedule.db
 celerybeat.pid
 
 # SageMath parsed files
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index d0ddc91a..f375d5b9 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -131,5 +131,7 @@ class Settings(BaseSettings):
     # Healthcheck
     HEALTHCHECK_URL: str | None = None
 
+    AWS_PROCESS_RECORDING_QUEUE_URL: str | None = None
+
 
 settings = Settings()
diff --git a/server/reflector/worker/app.py b/server/reflector/worker/app.py
index 5f1e4e74..3fb65c4e 100644
--- a/server/reflector/worker/app.py
+++ b/server/reflector/worker/app.py
@@ -16,11 +16,17 @@ else:
         [
             "reflector.pipelines.main_live_pipeline",
             "reflector.worker.healthcheck",
+            "reflector.worker.process",
         ]
     )
 
     # crontab
-    app.conf.beat_schedule = {}
+    app.conf.beat_schedule = {
+        "process_messages": {
+            "task": "reflector.worker.process.process_messages",
+            "schedule": 60.0,
+        }
+    }
 
     if settings.HEALTHCHECK_URL:
         app.conf.beat_schedule["healthcheck_ping"] = {
diff --git a/server/reflector/worker/process.py b/server/reflector/worker/process.py
new file mode 100644
index 00000000..4b4e6e27
--- /dev/null
+++ b/server/reflector/worker/process.py
@@ -0,0 +1,95 @@
+import json
+import os
+from urllib.parse import unquote
+
+import av
+import boto3
+import structlog
+from celery import shared_task
+from celery.utils.log import get_task_logger
+from reflector.db.transcripts import transcripts_controller
+from reflector.pipelines.main_live_pipeline import asynctask, task_pipeline_process
+from reflector.settings import settings
+
+logger = structlog.wrap_logger(get_task_logger(__name__))
+
+
+@shared_task
+def process_messages():
+    queue_url = settings.AWS_PROCESS_RECORDING_QUEUE_URL
+    if not queue_url:
+        logger.warning("No process recording queue url")
+        return
+    try:
+        logger.info("Receiving messages from: %s", queue_url)
+        sqs = boto3.client(
+            "sqs",
+            region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
+            aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
+            aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
+        )
+
+        response = sqs.receive_message(
+            QueueUrl=queue_url,
+            AttributeNames=["SentTimestamp"],
+            MaxNumberOfMessages=1,
+            MessageAttributeNames=["All"],
+            VisibilityTimeout=0,
+            WaitTimeSeconds=0,
+        )
+
+        for message in response.get("Messages", []):
+            receipt_handle = message["ReceiptHandle"]
+            body = json.loads(message["Body"])
+
+            for record in body.get("Records", []):
+                if record["eventName"].startswith("ObjectCreated"):
+                    bucket = record["s3"]["bucket"]["name"]
+                    key = unquote(record["s3"]["object"]["key"])
+                    process_recording.delay(bucket, key)
+
+            sqs.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
+            logger.info("Processed and deleted message: %s", message)
+
+    except Exception as e:
+        logger.error("process_messages", error=str(e))
+
+
+@shared_task
+@asynctask
+async def process_recording(bucket_name: str, object_key: str):
+    logger.info("Processing recording: %s/%s", bucket_name, object_key)
+
+    transcript = await transcripts_controller.add(
+        "",
+        source_language="en",
+        target_language="en",
+        user_id=None,
+    )
+    _, extension = os.path.splitext(object_key)
+    upload_filename = transcript.data_path / f"upload{extension}"
+    upload_filename.parent.mkdir(parents=True, exist_ok=True)
+
+    s3 = boto3.client(
+        "s3",
+        region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
+        aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
+        aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
+    )
+
+    with open(upload_filename, "wb") as f:
+        s3.download_fileobj(bucket_name, object_key, f)
+
+    container = av.open(upload_filename.as_posix())
+    try:
+        if not len(container.streams.audio):
+            raise Exception("File has no audio stream")
+    except Exception:
+        upload_filename.unlink()
+        raise
+    finally:
+        container.close()
+
+    await transcripts_controller.update(transcript, {"status": "uploaded"})
+
+    task_pipeline_process.delay(transcript_id=transcript.id)
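
Note on the new process_messages task: it polls the SQS queue that S3 publishes event notifications to, so it assumes the standard S3 notification message shape. A minimal sketch of the body it parses (field names per the S3 event format; bucket and key values are illustrative):

    # Illustrative only: an S3 ObjectCreated notification body as delivered via SQS.
    example_body = {
        "Records": [
            {
                "eventName": "ObjectCreated:Put",
                "s3": {
                    "bucket": {"name": "example-bucket"},
                    "object": {"key": "recordings/meeting-01.mp4"},
                },
            }
        ]
    }

Two caveats worth flagging: S3 URL-encodes object keys in these notifications with spaces emitted as "+", so keys containing spaces would decode more faithfully with urllib.parse.unquote_plus than with unquote; and with VisibilityTimeout=0 a received message stays visible to other consumers, so a concurrently polling worker could dispatch the same recording twice before the delete_message call lands.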