server: start moving to an external celery task

This commit is contained in:
2023-10-20 18:00:59 +02:00
committed by Mathieu Virbel
parent 16a8579272
commit 8bebb2a769
6 changed files with 392 additions and 1 deletions

210
server/poetry.lock generated
View File

@@ -308,6 +308,20 @@ typing-extensions = ">=4"
[package.extras]
tz = ["python-dateutil"]
[[package]]
name = "amqp"
version = "5.1.1"
description = "Low-level AMQP client for Python (fork of amqplib)."
optional = false
python-versions = ">=3.6"
files = [
{file = "amqp-5.1.1-py3-none-any.whl", hash = "sha256:6f0956d2c23d8fa6e7691934d8c3930eadb44972cbbd1a7ae3a520f735d43359"},
{file = "amqp-5.1.1.tar.gz", hash = "sha256:2c1b13fecc0893e946c65cbd5f36427861cffa4ea2201d8f6fca22e2a373b5e2"},
]
[package.dependencies]
vine = ">=5.0.0"
[[package]]
name = "annotated-types"
version = "0.6.0"
@@ -474,6 +488,17 @@ files = [
{file = "av-10.0.0.tar.gz", hash = "sha256:8afd3d5610e1086f3b2d8389d66672ea78624516912c93612de64dcaa4c67e05"},
]
[[package]]
name = "billiard"
version = "4.1.0"
description = "Python multiprocessing fork with improvements and bugfixes"
optional = false
python-versions = ">=3.7"
files = [
{file = "billiard-4.1.0-py3-none-any.whl", hash = "sha256:0f50d6be051c6b2b75bfbc8bfd85af195c5739c281d3f5b86a5640c65563614a"},
{file = "billiard-4.1.0.tar.gz", hash = "sha256:1ad2eeae8e28053d729ba3373d34d9d6e210f6e4d8bf0a9c64f92bd053f1edf5"},
]
[[package]]
name = "black"
version = "23.9.1"
@@ -556,6 +581,61 @@ urllib3 = ">=1.25.4,<1.27"
[package.extras]
crt = ["awscrt (==0.16.26)"]
[[package]]
name = "celery"
version = "5.3.4"
description = "Distributed Task Queue."
optional = false
python-versions = ">=3.8"
files = [
{file = "celery-5.3.4-py3-none-any.whl", hash = "sha256:1e6ed40af72695464ce98ca2c201ad0ef8fd192246f6c9eac8bba343b980ad34"},
{file = "celery-5.3.4.tar.gz", hash = "sha256:9023df6a8962da79eb30c0c84d5f4863d9793a466354cc931d7f72423996de28"},
]
[package.dependencies]
billiard = ">=4.1.0,<5.0"
click = ">=8.1.2,<9.0"
click-didyoumean = ">=0.3.0"
click-plugins = ">=1.1.1"
click-repl = ">=0.2.0"
kombu = ">=5.3.2,<6.0"
python-dateutil = ">=2.8.2"
tzdata = ">=2022.7"
vine = ">=5.0.0,<6.0"
[package.extras]
arangodb = ["pyArango (>=2.0.2)"]
auth = ["cryptography (==41.0.3)"]
azureblockblob = ["azure-storage-blob (>=12.15.0)"]
brotli = ["brotli (>=1.0.0)", "brotlipy (>=0.7.0)"]
cassandra = ["cassandra-driver (>=3.25.0,<4)"]
consul = ["python-consul2 (==0.1.5)"]
cosmosdbsql = ["pydocumentdb (==2.3.5)"]
couchbase = ["couchbase (>=3.0.0)"]
couchdb = ["pycouchdb (==1.14.2)"]
django = ["Django (>=2.2.28)"]
dynamodb = ["boto3 (>=1.26.143)"]
elasticsearch = ["elasticsearch (<8.0)"]
eventlet = ["eventlet (>=0.32.0)"]
gevent = ["gevent (>=1.5.0)"]
librabbitmq = ["librabbitmq (>=2.0.0)"]
memcache = ["pylibmc (==1.6.3)"]
mongodb = ["pymongo[srv] (>=4.0.2)"]
msgpack = ["msgpack (==1.0.5)"]
pymemcache = ["python-memcached (==1.59)"]
pyro = ["pyro4 (==4.82)"]
pytest = ["pytest-celery (==0.0.0)"]
redis = ["redis (>=4.5.2,!=4.5.5,<5.0.0)"]
s3 = ["boto3 (>=1.26.143)"]
slmq = ["softlayer-messaging (>=1.0.3)"]
solar = ["ephem (==4.1.4)"]
sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"]
sqs = ["boto3 (>=1.26.143)", "kombu[sqs] (>=5.3.0)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"]
tblib = ["tblib (>=1.3.0)", "tblib (>=1.5.0)"]
yaml = ["PyYAML (>=3.10)"]
zookeeper = ["kazoo (>=1.3.1)"]
zstd = ["zstandard (==0.21.0)"]
[[package]]
name = "certifi"
version = "2023.7.22"
@@ -744,6 +824,55 @@ files = [
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
[[package]]
name = "click-didyoumean"
version = "0.3.0"
description = "Enables git-like *did-you-mean* feature in click"
optional = false
python-versions = ">=3.6.2,<4.0.0"
files = [
{file = "click-didyoumean-0.3.0.tar.gz", hash = "sha256:f184f0d851d96b6d29297354ed981b7dd71df7ff500d82fa6d11f0856bee8035"},
{file = "click_didyoumean-0.3.0-py3-none-any.whl", hash = "sha256:a0713dc7a1de3f06bc0df5a9567ad19ead2d3d5689b434768a6145bff77c0667"},
]
[package.dependencies]
click = ">=7"
[[package]]
name = "click-plugins"
version = "1.1.1"
description = "An extension module for click to enable registering CLI commands via setuptools entry-points."
optional = false
python-versions = "*"
files = [
{file = "click-plugins-1.1.1.tar.gz", hash = "sha256:46ab999744a9d831159c3411bb0c79346d94a444df9a3a3742e9ed63645f264b"},
{file = "click_plugins-1.1.1-py2.py3-none-any.whl", hash = "sha256:5d262006d3222f5057fd81e1623d4443e41dcda5dc815c06b442aa3c02889fc8"},
]
[package.dependencies]
click = ">=4.0"
[package.extras]
dev = ["coveralls", "pytest (>=3.6)", "pytest-cov", "wheel"]
[[package]]
name = "click-repl"
version = "0.3.0"
description = "REPL plugin for Click"
optional = false
python-versions = ">=3.6"
files = [
{file = "click-repl-0.3.0.tar.gz", hash = "sha256:17849c23dba3d667247dc4defe1757fff98694e90fe37474f3feebb69ced26a9"},
{file = "click_repl-0.3.0-py3-none-any.whl", hash = "sha256:fb7e06deb8da8de86180a33a9da97ac316751c094c6899382da7feeeeb51b812"},
]
[package.dependencies]
click = ">=7.0"
prompt-toolkit = ">=3.0.36"
[package.extras]
testing = ["pytest (>=7.2.1)", "pytest-cov (>=4.0.0)", "tox (>=4.4.3)"]
[[package]]
name = "colorama"
version = "0.4.6"
@@ -1624,6 +1753,38 @@ files = [
cryptography = ">=3.4"
deprecated = "*"
[[package]]
name = "kombu"
version = "5.3.2"
description = "Messaging library for Python."
optional = false
python-versions = ">=3.8"
files = [
{file = "kombu-5.3.2-py3-none-any.whl", hash = "sha256:b753c9cfc9b1e976e637a7cbc1a65d446a22e45546cd996ea28f932082b7dc9e"},
{file = "kombu-5.3.2.tar.gz", hash = "sha256:0ba213f630a2cb2772728aef56ac6883dc3a2f13435e10048f6e97d48506dbbd"},
]
[package.dependencies]
amqp = ">=5.1.1,<6.0.0"
vine = "*"
[package.extras]
azureservicebus = ["azure-servicebus (>=7.10.0)"]
azurestoragequeues = ["azure-identity (>=1.12.0)", "azure-storage-queue (>=12.6.0)"]
confluentkafka = ["confluent-kafka (==2.1.1)"]
consul = ["python-consul2"]
librabbitmq = ["librabbitmq (>=2.0.0)"]
mongodb = ["pymongo (>=4.1.1)"]
msgpack = ["msgpack"]
pyro = ["pyro4"]
qpid = ["qpid-python (>=0.26)", "qpid-tools (>=0.26)"]
redis = ["redis (>=4.5.2)"]
slmq = ["softlayer-messaging (>=1.0.3)"]
sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"]
sqs = ["boto3 (>=1.26.143)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"]
yaml = ["PyYAML (>=3.10)"]
zookeeper = ["kazoo (>=2.8.0)"]
[[package]]
name = "levenshtein"
version = "0.21.1"
@@ -2151,6 +2312,20 @@ files = [
fastapi = ">=0.38.1,<1.0.0"
prometheus-client = ">=0.8.0,<1.0.0"
[[package]]
name = "prompt-toolkit"
version = "3.0.39"
description = "Library for building powerful interactive command lines in Python"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "prompt_toolkit-3.0.39-py3-none-any.whl", hash = "sha256:9dffbe1d8acf91e3de75f3b544e4842382fc06c6babe903ac9acb74dc6e08d88"},
{file = "prompt_toolkit-3.0.39.tar.gz", hash = "sha256:04505ade687dc26dc4284b1ad19a83be2f2afe83e7a828ace0c72f3a1df72aac"},
]
[package.dependencies]
wcwidth = "*"
[[package]]
name = "protobuf"
version = "4.24.4"
@@ -3438,6 +3613,17 @@ files = [
{file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"},
]
[[package]]
name = "tzdata"
version = "2023.3"
description = "Provider of IANA time zone data"
optional = false
python-versions = ">=2"
files = [
{file = "tzdata-2023.3-py2.py3-none-any.whl", hash = "sha256:7e65763eef3120314099b6939b5546db7adce1e7d6f2e179e3df563c70511eda"},
{file = "tzdata-2023.3.tar.gz", hash = "sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a"},
]
[[package]]
name = "urllib3"
version = "1.26.17"
@@ -3523,6 +3709,17 @@ dev = ["Cython (>=0.29.32,<0.30.0)", "Sphinx (>=4.1.2,<4.2.0)", "aiohttp", "flak
docs = ["Sphinx (>=4.1.2,<4.2.0)", "sphinx-rtd-theme (>=0.5.2,<0.6.0)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"]
test = ["Cython (>=0.29.32,<0.30.0)", "aiohttp", "flake8 (>=3.9.2,<3.10.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=22.0.0,<22.1.0)", "pycodestyle (>=2.7.0,<2.8.0)"]
[[package]]
name = "vine"
version = "5.0.0"
description = "Promises, promises, promises."
optional = false
python-versions = ">=3.6"
files = [
{file = "vine-5.0.0-py2.py3-none-any.whl", hash = "sha256:4c9dceab6f76ed92105027c49c823800dd33cacce13bdedc5b914e3514b7fb30"},
{file = "vine-5.0.0.tar.gz", hash = "sha256:7d3b1624a953da82ef63462013bbd271d3eb75751489f9807598e8f340bd637e"},
]
[[package]]
name = "watchfiles"
version = "0.20.0"
@@ -3557,6 +3754,17 @@ files = [
[package.dependencies]
anyio = ">=3.0.0"
[[package]]
name = "wcwidth"
version = "0.2.8"
description = "Measures the displayed width of unicode strings in a terminal"
optional = false
python-versions = "*"
files = [
{file = "wcwidth-0.2.8-py2.py3-none-any.whl", hash = "sha256:77f719e01648ed600dfa5402c347481c0992263b81a027344f3e1ba25493a704"},
{file = "wcwidth-0.2.8.tar.gz", hash = "sha256:8705c569999ffbb4f6a87c6d1b80f324bd6db952f5eb0b95bc07517f4c1813d4"},
]
[[package]]
name = "websockets"
version = "11.0.3"
@@ -3838,4 +4046,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "61578467a70980ff9c2dc0cd787b6410b91d7c5fd2bb4c46b6951ec82690ef67"
content-hash = "fda9f13784a64add559abb2266d60eeef8f28d2b5f369633630f4fed14daa99c"

View File

@@ -33,6 +33,7 @@ prometheus-fastapi-instrumentator = "^6.1.0"
sentencepiece = "^0.1.99"
protobuf = "^4.24.3"
profanityfilter = "^2.0.6"
celery = "^5.3.4"
[tool.poetry.group.dev.dependencies]

View File

@@ -113,5 +113,9 @@ class Settings(BaseSettings):
# Min transcript length to generate topic + summary
MIN_TRANSCRIPT_LENGTH: int = 750
# Celery
CELERY_BROKER_URL: str = "redis://localhost:6379/1"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
settings = Settings()

View File

@@ -0,0 +1,2 @@
import reflector.tasks.post_transcript # noqa
import reflector.tasks.worker # noqa

View File

@@ -0,0 +1,170 @@
from reflector.logger import logger
from reflector.processors import (
Pipeline,
Processor,
TranscriptFinalLongSummaryProcessor,
TranscriptFinalShortSummaryProcessor,
TranscriptFinalTitleProcessor,
)
from reflector.processors.base import BroadcastProcessor
from reflector.processors.types import (
FinalLongSummary,
FinalShortSummary,
FinalTitle,
TitleSummary,
)
from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.tasks.worker import celery
from reflector.views.rtc_offer import PipelineEvent, TranscriptionContext
from reflector.views.transcripts import Transcript, transcripts_controller
class TranscriptAudioDiarizationProcessor(Processor):
INPUT_TYPE = Transcript
OUTPUT_TYPE = TitleSummary
async def _push(self, data: Transcript):
# Gather diarization data
diarization = [
{"start": 0.0, "stop": 4.9, "speaker": 2},
{"start": 5.6, "stop": 6.7, "speaker": 2},
{"start": 7.3, "stop": 8.9, "speaker": 2},
{"start": 7.3, "stop": 7.9, "speaker": 0},
{"start": 9.4, "stop": 11.2, "speaker": 2},
{"start": 9.7, "stop": 10.0, "speaker": 0},
{"start": 10.0, "stop": 10.1, "speaker": 0},
{"start": 11.7, "stop": 16.1, "speaker": 2},
{"start": 11.8, "stop": 12.1, "speaker": 1},
{"start": 16.4, "stop": 21.0, "speaker": 2},
{"start": 21.1, "stop": 22.6, "speaker": 2},
{"start": 24.7, "stop": 31.9, "speaker": 2},
{"start": 32.0, "stop": 32.8, "speaker": 1},
{"start": 33.4, "stop": 37.8, "speaker": 2},
{"start": 37.9, "stop": 40.3, "speaker": 0},
{"start": 39.2, "stop": 40.4, "speaker": 2},
{"start": 40.7, "stop": 41.4, "speaker": 0},
{"start": 41.6, "stop": 45.7, "speaker": 2},
{"start": 46.4, "stop": 53.1, "speaker": 2},
{"start": 53.6, "stop": 56.5, "speaker": 2},
{"start": 54.9, "stop": 75.4, "speaker": 1},
{"start": 57.3, "stop": 58.0, "speaker": 2},
{"start": 65.7, "stop": 66.0, "speaker": 2},
{"start": 75.8, "stop": 78.8, "speaker": 1},
{"start": 79.0, "stop": 82.6, "speaker": 1},
{"start": 83.2, "stop": 83.3, "speaker": 1},
{"start": 84.5, "stop": 94.3, "speaker": 1},
{"start": 95.1, "stop": 100.7, "speaker": 1},
{"start": 100.7, "stop": 102.0, "speaker": 0},
{"start": 100.7, "stop": 101.8, "speaker": 1},
{"start": 102.0, "stop": 103.0, "speaker": 1},
{"start": 103.0, "stop": 103.7, "speaker": 0},
{"start": 103.7, "stop": 103.8, "speaker": 1},
{"start": 103.8, "stop": 113.9, "speaker": 0},
{"start": 114.7, "stop": 117.0, "speaker": 0},
{"start": 117.0, "stop": 117.4, "speaker": 1},
]
# now reapply speaker to topics (if any)
# topics is a list[BaseModel] with an attribute words
# words is a list[BaseModel] with text, start and speaker attribute
# mutate in place
for topic in data.topics:
for word in topic.words:
for d in diarization:
if d["start"] <= word.start <= d["stop"]:
word.speaker = d["speaker"]
topics = data.topics[:]
await transcripts_controller.update(
data,
{
"topics": [topic.model_dump(mode="json") for topic in data.topics],
},
)
# emit them
for topic in topics:
transcript = ProcessorTranscript(words=topic.words)
await self.emit(
TitleSummary(
title=topic.title,
summary=topic.summary,
timestamp=topic.timestamp,
duration=0,
transcript=transcript,
)
)
@celery.task(name="post_transcript")
async def post_transcript_pipeline(transcript_id: str):
# get transcript
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
logger.error("Transcript not found", transcript_id=transcript_id)
return
ctx = TranscriptionContext(logger=logger.bind(transcript_id=transcript_id))
event_callback = None
event_callback_args = None
async def on_final_short_summary(summary: FinalShortSummary):
ctx.logger.info("FinalShortSummary", final_short_summary=summary)
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.FINAL_SHORT_SUMMARY,
args=event_callback_args,
data=summary,
)
async def on_final_long_summary(summary: FinalLongSummary):
ctx.logger.info("FinalLongSummary", final_summary=summary)
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.FINAL_LONG_SUMMARY,
args=event_callback_args,
data=summary,
)
async def on_final_title(title: FinalTitle):
ctx.logger.info("FinalTitle", final_title=title)
# send to callback (eg. websocket)
if event_callback:
await event_callback(
event=PipelineEvent.FINAL_TITLE,
args=event_callback_args,
data=title,
)
ctx.logger.info("Starting pipeline (diarization)")
ctx.pipeline = Pipeline(
TranscriptAudioDiarizationProcessor(),
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(),
TranscriptFinalLongSummaryProcessor.as_threaded(),
TranscriptFinalShortSummaryProcessor.as_threaded(),
]
),
)
await ctx.pipeline.push(transcript)
await ctx.pipeline.flush()
if __name__ == "__main__":
import argparse
import asyncio
parser = argparse.ArgumentParser()
parser.add_argument("transcript_id", type=str)
args = parser.parse_args()
asyncio.run(post_transcript_pipeline(args.transcript_id))

View File

@@ -0,0 +1,6 @@
from celery import Celery
from reflector.settings import settings
celery = Celery(__name__)
celery.conf.broker_url = settings.CELERY_BROKER_URL
celery.conf.result_backend = settings.CELERY_RESULT_BACKEND