From d94e2911c3f9addf4e261c1407bda92ad4b9d41f Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 4 Aug 2023 10:24:11 +0200 Subject: [PATCH] Serverless GPU support on banana.dev (#106) * serverless: implement banana backend for both audio and LLM Related to monadical-sas/reflector-gpu-banana project * serverless: got llm working on banana ! * tests: fixes * serverless: fix dockerfile to use fastapi server + httpx --- server/Dockerfile | 4 +- server/poetry.lock | 243 +++++++++++++++++- server/pyproject.toml | 6 +- server/reflector/llm/__init__.py | 2 - server/reflector/llm/base.py | 32 ++- server/reflector/llm/llm_banana.py | 41 +++ .../processors/audio_transcript_auto.py | 37 ++- .../processors/audio_transcript_banana.py | 85 ++++++ .../processors/audio_transcript_whisper.py | 4 + .../processors/transcript_topic_detector.py | 2 +- server/reflector/settings.py | 29 ++- server/reflector/storage/__init__.py | 1 + server/reflector/storage/base.py | 47 ++++ server/reflector/storage/storage_aws.py | 67 +++++ server/reflector/tools/process.py | 25 +- server/reflector/utils/retry.py | 29 +++ server/tests/test_processors_pipeline.py | 1 + 17 files changed, 602 insertions(+), 53 deletions(-) create mode 100644 server/reflector/llm/llm_banana.py create mode 100644 server/reflector/processors/audio_transcript_banana.py create mode 100644 server/reflector/storage/__init__.py create mode 100644 server/reflector/storage/base.py create mode 100644 server/reflector/storage/storage_aws.py create mode 100644 server/reflector/utils/retry.py diff --git a/server/Dockerfile b/server/Dockerfile index 248bea9b..7a0aa8f7 100644 --- a/server/Dockerfile +++ b/server/Dockerfile @@ -18,7 +18,7 @@ COPY pyproject.toml poetry.lock /tmp RUN pip install "poetry==$POETRY_VERSION" RUN python -m venv /venv RUN . /venv/bin/activate && poetry config virtualenvs.create false -RUN . /venv/bin/activate && poetry install --only main --no-root --no-interaction --no-ansi +RUN . /venv/bin/activate && poetry install --only main,aws --no-root --no-interaction --no-ansi # bootstrap FROM base AS final @@ -26,4 +26,4 @@ COPY --from=builder /venv /venv RUN mkdir -p /app COPY reflector /app/reflector WORKDIR /app -CMD ["/venv/bin/python", "-m", "reflector.server"] +CMD ["/venv/bin/python", "-m", "reflector.app"] diff --git a/server/poetry.lock b/server/poetry.lock index bef08557..71206cba 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1,5 +1,45 @@ # This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +[[package]] +name = "aioboto3" +version = "11.2.0" +description = "Async boto3 wrapper" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "aioboto3-11.2.0-py3-none-any.whl", hash = "sha256:df4b83c3943b009a4dcd9f397f9f0491a374511b1ef37545082a771ca1e549fb"}, + {file = "aioboto3-11.2.0.tar.gz", hash = "sha256:c7f6234fd73efcb60ab6fca383fec33bb6352ca1832f252eac810cd6674f1748"}, +] + +[package.dependencies] +aiobotocore = {version = "2.5.0", extras = ["boto3"]} + +[package.extras] +chalice = ["chalice (>=1.24.0)"] +s3cse = ["cryptography (>=2.3.1)"] + +[[package]] +name = "aiobotocore" +version = "2.5.0" +description = "Async client for aws services using botocore and aiohttp" +optional = false +python-versions = ">=3.7" +files = [ + {file = "aiobotocore-2.5.0-py3-none-any.whl", hash = "sha256:9a2a022d7b78ec9a2af0de589916d2721cddbf96264401b78d7a73c1a1435f3b"}, + {file = "aiobotocore-2.5.0.tar.gz", hash = "sha256:6a5b397cddd4f81026aa91a14c7dd2650727425740a5af8ba75127ff663faf67"}, +] + +[package.dependencies] +aiohttp = ">=3.3.1" +aioitertools = ">=0.5.1" +boto3 = {version = ">=1.26.76,<1.26.77", optional = true, markers = "extra == \"boto3\""} +botocore = ">=1.29.76,<1.29.77" +wrapt = ">=1.10.10" + +[package.extras] +awscli = ["awscli (>=1.27.76,<1.27.77)"] +boto3 = ["boto3 (>=1.26.76,<1.26.77)"] + [[package]] name = "aiohttp" version = "3.8.5" @@ -137,6 +177,17 @@ files = [ dnspython = ">=2.0.0" ifaddr = ">=0.2.0" +[[package]] +name = "aioitertools" +version = "0.11.0" +description = "itertools and builtins for AsyncIO and mixed iterables" +optional = false +python-versions = ">=3.6" +files = [ + {file = "aioitertools-0.11.0-py3-none-any.whl", hash = "sha256:04b95e3dab25b449def24d7df809411c10e62aab0cbe31a50ca4e68748c43394"}, + {file = "aioitertools-0.11.0.tar.gz", hash = "sha256:42c68b8dd3a69c2bf7f2233bf7df4bb58b557bca5252ac02ed5187bbc67d6831"}, +] + [[package]] name = "aiortc" version = "1.5.0" @@ -380,6 +431,44 @@ d = ["aiohttp (>=3.7.4)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] +[[package]] +name = "boto3" +version = "1.26.76" +description = "The AWS SDK for Python" +optional = false +python-versions = ">= 3.7" +files = [ + {file = "boto3-1.26.76-py3-none-any.whl", hash = "sha256:b4c2969b7677762914394b8273cc1905dfe5b71f250741c1a575487ae357e729"}, + {file = "boto3-1.26.76.tar.gz", hash = "sha256:30c7d967ed1c6b5a05643e42cae9d4d36c3f1cb6782637ddc7007a104cfd9027"}, +] + +[package.dependencies] +botocore = ">=1.29.76,<1.30.0" +jmespath = ">=0.7.1,<2.0.0" +s3transfer = ">=0.6.0,<0.7.0" + +[package.extras] +crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] + +[[package]] +name = "botocore" +version = "1.29.76" +description = "Low-level, data-driven core of boto 3." +optional = false +python-versions = ">= 3.7" +files = [ + {file = "botocore-1.29.76-py3-none-any.whl", hash = "sha256:70735b00cd529f152992231ca6757e458e5ec25db43767b3526e9a35b2f143b7"}, + {file = "botocore-1.29.76.tar.gz", hash = "sha256:c2f67b6b3f8acf2968eafca06526f07b9fb0d27bac4c68a635d51abb675134a7"}, +] + +[package.dependencies] +jmespath = ">=0.7.1,<2.0.0" +python-dateutil = ">=2.1,<3.0.0" +urllib3 = ">=1.25.4,<1.27" + +[package.extras] +crt = ["awscrt (==0.16.9)"] + [[package]] name = "certifi" version = "2023.7.22" @@ -1127,6 +1216,17 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "jmespath" +version = "1.0.1" +description = "JSON Matching Expressions" +optional = false +python-versions = ">=3.7" +files = [ + {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, + {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, +] + [[package]] name = "loguru" version = "0.7.0" @@ -1751,6 +1851,20 @@ pytest = ">=7.0.0" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] +[[package]] +name = "python-dateutil" +version = "2.8.2" +description = "Extensions to the standard Python datetime module" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, + {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, +] + +[package.dependencies] +six = ">=1.5" + [[package]] name = "python-dotenv" version = "1.0.0" @@ -1835,6 +1949,23 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "s3transfer" +version = "0.6.1" +description = "An Amazon S3 Transfer Manager" +optional = false +python-versions = ">= 3.7" +files = [ + {file = "s3transfer-0.6.1-py3-none-any.whl", hash = "sha256:3c0da2d074bf35d6870ef157158641178a4204a6e689e82546083e31e0311346"}, + {file = "s3transfer-0.6.1.tar.gz", hash = "sha256:640bb492711f4c0c0905e1f62b6aaeb771881935ad27884852411f8e9cacbca9"}, +] + +[package.dependencies] +botocore = ">=1.12.36,<2.0a.0" + +[package.extras] +crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"] + [[package]] name = "sentry-sdk" version = "1.29.2" @@ -1878,6 +2009,17 @@ starlette = ["starlette (>=0.19.1)"] starlite = ["starlite (>=1.48)"] tornado = ["tornado (>=5)"] +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + [[package]] name = "sniffio" version = "1.3.0" @@ -2069,20 +2211,19 @@ files = [ [[package]] name = "urllib3" -version = "2.0.4" +version = "1.26.16" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false -python-versions = ">=3.7" +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ - {file = "urllib3-2.0.4-py3-none-any.whl", hash = "sha256:de7df1803967d2c2a98e4b11bb7d6bd9210474c46e8a0401514e3a42a75ebde4"}, - {file = "urllib3-2.0.4.tar.gz", hash = "sha256:8d22f86aae8ef5e410d4f539fde9ce6b2113a001bb4d189e0aed70642d602b11"}, + {file = "urllib3-1.26.16-py2.py3-none-any.whl", hash = "sha256:8d36afa7616d8ab714608411b4a3b13e58f463aee519024578e062e141dce20f"}, + {file = "urllib3-1.26.16.tar.gz", hash = "sha256:8f135f6502756bde6b2a9b28989df5fbe87c9970cecaa69041edcce7f0589b14"}, ] [package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] -secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] -socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] -zstd = ["zstandard (>=0.18.0)"] +brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] +secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] +socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] [[package]] name = "uvicorn" @@ -2280,6 +2421,90 @@ files = [ [package.extras] dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] +[[package]] +name = "wrapt" +version = "1.15.0" +description = "Module for decorators, wrappers and monkey patching." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" +files = [ + {file = "wrapt-1.15.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:ca1cccf838cd28d5a0883b342474c630ac48cac5df0ee6eacc9c7290f76b11c1"}, + {file = "wrapt-1.15.0-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:e826aadda3cae59295b95343db8f3d965fb31059da7de01ee8d1c40a60398b29"}, + {file = "wrapt-1.15.0-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:5fc8e02f5984a55d2c653f5fea93531e9836abbd84342c1d1e17abc4a15084c2"}, + {file = "wrapt-1.15.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:96e25c8603a155559231c19c0349245eeb4ac0096fe3c1d0be5c47e075bd4f46"}, + {file = "wrapt-1.15.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:40737a081d7497efea35ab9304b829b857f21558acfc7b3272f908d33b0d9d4c"}, + {file = "wrapt-1.15.0-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:f87ec75864c37c4c6cb908d282e1969e79763e0d9becdfe9fe5473b7bb1e5f09"}, + {file = "wrapt-1.15.0-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:1286eb30261894e4c70d124d44b7fd07825340869945c79d05bda53a40caa079"}, + {file = "wrapt-1.15.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:493d389a2b63c88ad56cdc35d0fa5752daac56ca755805b1b0c530f785767d5e"}, + {file = "wrapt-1.15.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:58d7a75d731e8c63614222bcb21dd992b4ab01a399f1f09dd82af17bbfc2368a"}, + {file = "wrapt-1.15.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:21f6d9a0d5b3a207cdf7acf8e58d7d13d463e639f0c7e01d82cdb671e6cb7923"}, + {file = "wrapt-1.15.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ce42618f67741d4697684e501ef02f29e758a123aa2d669e2d964ff734ee00ee"}, + {file = "wrapt-1.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41d07d029dd4157ae27beab04d22b8e261eddfc6ecd64ff7000b10dc8b3a5727"}, + {file = "wrapt-1.15.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54accd4b8bc202966bafafd16e69da9d5640ff92389d33d28555c5fd4f25ccb7"}, + {file = "wrapt-1.15.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fbfbca668dd15b744418265a9607baa970c347eefd0db6a518aaf0cfbd153c0"}, + {file = "wrapt-1.15.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:76e9c727a874b4856d11a32fb0b389afc61ce8aaf281ada613713ddeadd1cfec"}, + {file = "wrapt-1.15.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e20076a211cd6f9b44a6be58f7eeafa7ab5720eb796975d0c03f05b47d89eb90"}, + {file = "wrapt-1.15.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a74d56552ddbde46c246b5b89199cb3fd182f9c346c784e1a93e4dc3f5ec9975"}, + {file = "wrapt-1.15.0-cp310-cp310-win32.whl", hash = "sha256:26458da5653aa5b3d8dc8b24192f574a58984c749401f98fff994d41d3f08da1"}, + {file = "wrapt-1.15.0-cp310-cp310-win_amd64.whl", hash = "sha256:75760a47c06b5974aa5e01949bf7e66d2af4d08cb8c1d6516af5e39595397f5e"}, + {file = "wrapt-1.15.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ba1711cda2d30634a7e452fc79eabcadaffedf241ff206db2ee93dd2c89a60e7"}, + {file = "wrapt-1.15.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:56374914b132c702aa9aa9959c550004b8847148f95e1b824772d453ac204a72"}, + {file = "wrapt-1.15.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a89ce3fd220ff144bd9d54da333ec0de0399b52c9ac3d2ce34b569cf1a5748fb"}, + {file = "wrapt-1.15.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3bbe623731d03b186b3d6b0d6f51865bf598587c38d6f7b0be2e27414f7f214e"}, + {file = "wrapt-1.15.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3abbe948c3cbde2689370a262a8d04e32ec2dd4f27103669a45c6929bcdbfe7c"}, + {file = "wrapt-1.15.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b67b819628e3b748fd3c2192c15fb951f549d0f47c0449af0764d7647302fda3"}, + {file = "wrapt-1.15.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7eebcdbe3677e58dd4c0e03b4f2cfa346ed4049687d839adad68cc38bb559c92"}, + {file = "wrapt-1.15.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:74934ebd71950e3db69960a7da29204f89624dde411afbfb3b4858c1409b1e98"}, + {file = "wrapt-1.15.0-cp311-cp311-win32.whl", hash = "sha256:bd84395aab8e4d36263cd1b9308cd504f6cf713b7d6d3ce25ea55670baec5416"}, + {file = "wrapt-1.15.0-cp311-cp311-win_amd64.whl", hash = "sha256:a487f72a25904e2b4bbc0817ce7a8de94363bd7e79890510174da9d901c38705"}, + {file = "wrapt-1.15.0-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:4ff0d20f2e670800d3ed2b220d40984162089a6e2c9646fdb09b85e6f9a8fc29"}, + {file = "wrapt-1.15.0-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:9ed6aa0726b9b60911f4aed8ec5b8dd7bf3491476015819f56473ffaef8959bd"}, + {file = "wrapt-1.15.0-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:896689fddba4f23ef7c718279e42f8834041a21342d95e56922e1c10c0cc7afb"}, + {file = "wrapt-1.15.0-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:75669d77bb2c071333417617a235324a1618dba66f82a750362eccbe5b61d248"}, + {file = "wrapt-1.15.0-cp35-cp35m-win32.whl", hash = "sha256:fbec11614dba0424ca72f4e8ba3c420dba07b4a7c206c8c8e4e73f2e98f4c559"}, + {file = "wrapt-1.15.0-cp35-cp35m-win_amd64.whl", hash = "sha256:fd69666217b62fa5d7c6aa88e507493a34dec4fa20c5bd925e4bc12fce586639"}, + {file = "wrapt-1.15.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b0724f05c396b0a4c36a3226c31648385deb6a65d8992644c12a4963c70326ba"}, + {file = "wrapt-1.15.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbeccb1aa40ab88cd29e6c7d8585582c99548f55f9b2581dfc5ba68c59a85752"}, + {file = "wrapt-1.15.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:38adf7198f8f154502883242f9fe7333ab05a5b02de7d83aa2d88ea621f13364"}, + {file = "wrapt-1.15.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:578383d740457fa790fdf85e6d346fda1416a40549fe8db08e5e9bd281c6a475"}, + {file = "wrapt-1.15.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:a4cbb9ff5795cd66f0066bdf5947f170f5d63a9274f99bdbca02fd973adcf2a8"}, + {file = "wrapt-1.15.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:af5bd9ccb188f6a5fdda9f1f09d9f4c86cc8a539bd48a0bfdc97723970348418"}, + {file = "wrapt-1.15.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:b56d5519e470d3f2fe4aa7585f0632b060d532d0696c5bdfb5e8319e1d0f69a2"}, + {file = "wrapt-1.15.0-cp36-cp36m-win32.whl", hash = "sha256:77d4c1b881076c3ba173484dfa53d3582c1c8ff1f914c6461ab70c8428b796c1"}, + {file = "wrapt-1.15.0-cp36-cp36m-win_amd64.whl", hash = "sha256:077ff0d1f9d9e4ce6476c1a924a3332452c1406e59d90a2cf24aeb29eeac9420"}, + {file = "wrapt-1.15.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5c5aa28df055697d7c37d2099a7bc09f559d5053c3349b1ad0c39000e611d317"}, + {file = "wrapt-1.15.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a8564f283394634a7a7054b7983e47dbf39c07712d7b177b37e03f2467a024e"}, + {file = "wrapt-1.15.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780c82a41dc493b62fc5884fb1d3a3b81106642c5c5c78d6a0d4cbe96d62ba7e"}, + {file = "wrapt-1.15.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e169e957c33576f47e21864cf3fc9ff47c223a4ebca8960079b8bd36cb014fd0"}, + {file = "wrapt-1.15.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:b02f21c1e2074943312d03d243ac4388319f2456576b2c6023041c4d57cd7019"}, + {file = "wrapt-1.15.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f2e69b3ed24544b0d3dbe2c5c0ba5153ce50dcebb576fdc4696d52aa22db6034"}, + {file = "wrapt-1.15.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d787272ed958a05b2c86311d3a4135d3c2aeea4fc655705f074130aa57d71653"}, + {file = "wrapt-1.15.0-cp37-cp37m-win32.whl", hash = "sha256:02fce1852f755f44f95af51f69d22e45080102e9d00258053b79367d07af39c0"}, + {file = "wrapt-1.15.0-cp37-cp37m-win_amd64.whl", hash = "sha256:abd52a09d03adf9c763d706df707c343293d5d106aea53483e0ec8d9e310ad5e"}, + {file = "wrapt-1.15.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cdb4f085756c96a3af04e6eca7f08b1345e94b53af8921b25c72f096e704e145"}, + {file = "wrapt-1.15.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:230ae493696a371f1dbffaad3dafbb742a4d27a0afd2b1aecebe52b740167e7f"}, + {file = "wrapt-1.15.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63424c681923b9f3bfbc5e3205aafe790904053d42ddcc08542181a30a7a51bd"}, + {file = "wrapt-1.15.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6bcbfc99f55655c3d93feb7ef3800bd5bbe963a755687cbf1f490a71fb7794b"}, + {file = "wrapt-1.15.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c99f4309f5145b93eca6e35ac1a988f0dc0a7ccf9ccdcd78d3c0adf57224e62f"}, + {file = "wrapt-1.15.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b130fe77361d6771ecf5a219d8e0817d61b236b7d8b37cc045172e574ed219e6"}, + {file = "wrapt-1.15.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:96177eb5645b1c6985f5c11d03fc2dbda9ad24ec0f3a46dcce91445747e15094"}, + {file = "wrapt-1.15.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d5fe3e099cf07d0fb5a1e23d399e5d4d1ca3e6dfcbe5c8570ccff3e9208274f7"}, + {file = "wrapt-1.15.0-cp38-cp38-win32.whl", hash = "sha256:abd8f36c99512755b8456047b7be10372fca271bf1467a1caa88db991e7c421b"}, + {file = "wrapt-1.15.0-cp38-cp38-win_amd64.whl", hash = "sha256:b06fa97478a5f478fb05e1980980a7cdf2712015493b44d0c87606c1513ed5b1"}, + {file = "wrapt-1.15.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2e51de54d4fb8fb50d6ee8327f9828306a959ae394d3e01a1ba8b2f937747d86"}, + {file = "wrapt-1.15.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0970ddb69bba00670e58955f8019bec4a42d1785db3faa043c33d81de2bf843c"}, + {file = "wrapt-1.15.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76407ab327158c510f44ded207e2f76b657303e17cb7a572ffe2f5a8a48aa04d"}, + {file = "wrapt-1.15.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd525e0e52a5ff16653a3fc9e3dd827981917d34996600bbc34c05d048ca35cc"}, + {file = "wrapt-1.15.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d37ac69edc5614b90516807de32d08cb8e7b12260a285ee330955604ed9dd29"}, + {file = "wrapt-1.15.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:078e2a1a86544e644a68422f881c48b84fef6d18f8c7a957ffd3f2e0a74a0d4a"}, + {file = "wrapt-1.15.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:2cf56d0e237280baed46f0b5316661da892565ff58309d4d2ed7dba763d984b8"}, + {file = "wrapt-1.15.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7dc0713bf81287a00516ef43137273b23ee414fe41a3c14be10dd95ed98a2df9"}, + {file = "wrapt-1.15.0-cp39-cp39-win32.whl", hash = "sha256:46ed616d5fb42f98630ed70c3529541408166c22cdfd4540b88d5f21006b0eff"}, + {file = "wrapt-1.15.0-cp39-cp39-win_amd64.whl", hash = "sha256:eef4d64c650f33347c1f9266fa5ae001440b232ad9b98f1f43dfe7a79435c0a6"}, + {file = "wrapt-1.15.0-py3-none-any.whl", hash = "sha256:64b1df0f83706b4ef4cfb4fb0e4c2669100fd7ecacfb59e091fad300d4e04640"}, + {file = "wrapt-1.15.0.tar.gz", hash = "sha256:d06730c6aed78cee4126234cf2d071e01b44b915e725a6cb439a879ec9754a3a"}, +] + [[package]] name = "yarl" version = "1.9.2" @@ -2370,4 +2595,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "f94162f1217c3767f792902a9a45fec81275ae3a98f2809662bf3a3d574984e2" +content-hash = "1a98a080ce035b381521426c9d6f9f80e8656258beab6cdff95ea90cf6c77e85" diff --git a/server/pyproject.toml b/server/pyproject.toml index 3fb95f39..dc446796 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -21,6 +21,7 @@ structlog = "^23.1.0" uvicorn = {extras = ["standard"], version = "^0.23.1"} fastapi = "^0.100.1" sentry-sdk = {extras = ["fastapi"], version = "^1.29.2"} +httpx = "^0.24.1" [tool.poetry.group.dev.dependencies] @@ -28,7 +29,6 @@ black = "^23.7.0" [tool.poetry.group.client.dependencies] -httpx = "^0.24.1" pyaudio = "^0.2.13" stamina = "^23.1.0" @@ -38,6 +38,10 @@ pytest-aiohttp = "^1.0.4" pytest-asyncio = "^0.21.1" pytest = "^7.4.0" + +[tool.poetry.group.aws.dependencies] +aioboto3 = "^11.2.0" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/server/reflector/llm/__init__.py b/server/reflector/llm/__init__.py index fddf3919..f0dda3b6 100644 --- a/server/reflector/llm/__init__.py +++ b/server/reflector/llm/__init__.py @@ -1,3 +1 @@ from .base import LLM # noqa: F401 -from . import llm_oobagooda # noqa: F401 -from . import llm_openai # noqa: F401 diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py index 031e38aa..cc5b7245 100644 --- a/server/reflector/llm/base.py +++ b/server/reflector/llm/base.py @@ -1,6 +1,7 @@ from reflector.logger import logger from reflector.settings import settings -import asyncio +from reflector.utils.retry import retry +import importlib import json import re @@ -13,7 +14,7 @@ class LLM: cls._registry[name] = klass @classmethod - def instance(cls): + def get_instance(cls, name=None): """ Return an instance depending on the settings. Settings used: @@ -21,22 +22,19 @@ class LLM: - `LLM_BACKEND`: key of the backend, defaults to `oobagooda` - `LLM_URL`: url of the backend """ - return cls._registry[settings.LLM_BACKEND]() + if name is None: + name = settings.LLM_BACKEND + if name not in cls._registry: + module_name = f"reflector.llm.llm_{name}" + importlib.import_module(module_name) + return cls._registry[name]() - async def generate( - self, prompt: str, retry_count: int = 5, retry_interval: int = 1, **kwargs - ) -> dict: - while retry_count > 0: - try: - result = await self._generate(prompt=prompt, **kwargs) - break - except Exception: - logger.exception("Failed to call llm") - retry_count -= 1 - await asyncio.sleep(retry_interval) - - if retry_count == 0: - raise Exception("Failed to call llm after retrying") + async def generate(self, prompt: str, **kwargs) -> dict: + try: + result = await retry(self._generate)(prompt=prompt, **kwargs) + except Exception: + logger.exception("Failed to call llm after retrying") + raise if isinstance(result, str): result = self._parse_json(result) diff --git a/server/reflector/llm/llm_banana.py b/server/reflector/llm/llm_banana.py new file mode 100644 index 00000000..0a0bfc93 --- /dev/null +++ b/server/reflector/llm/llm_banana.py @@ -0,0 +1,41 @@ +from reflector.llm.base import LLM +from reflector.settings import settings +from reflector.utils.retry import retry +import httpx + + +class BananaLLM(LLM): + def __init__(self): + super().__init__() + self.timeout = settings.LLM_TIMEOUT + self.headers = { + "X-Banana-API-Key": settings.LLM_BANANA_API_KEY, + "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY, + } + + async def _generate(self, prompt: str, **kwargs): + async with httpx.AsyncClient() as client: + response = await retry(client.post)( + settings.LLM_URL, + headers=self.headers, + json={"prompt": prompt}, + timeout=self.timeout, + ) + response.raise_for_status() + text = response.json()["text"] + text = text[len(prompt) :] # remove prompt + return text + + +LLM.register("banana", BananaLLM) + +if __name__ == "__main__": + + async def main(): + llm = BananaLLM() + result = await llm.generate("Hello, my name is") + print(result) + + import asyncio + + asyncio.run(main()) diff --git a/server/reflector/processors/audio_transcript_auto.py b/server/reflector/processors/audio_transcript_auto.py index 9b792009..339e5633 100644 --- a/server/reflector/processors/audio_transcript_auto.py +++ b/server/reflector/processors/audio_transcript_auto.py @@ -1,19 +1,38 @@ from reflector.processors.base import Processor from reflector.processors.audio_transcript import AudioTranscriptProcessor -from reflector.processors.audio_transcript_whisper import ( - AudioTranscriptWhisperProcessor, -) from reflector.processors.types import AudioFile +from reflector.settings import settings +import importlib class AudioTranscriptAutoProcessor(AudioTranscriptProcessor): - BACKENDS = { - "whisper": AudioTranscriptWhisperProcessor, - } - BACKEND_DEFAULT = "whisper" + _registry = {} - def __init__(self, backend=None, **kwargs): - self.processor = self.BACKENDS[backend or self.BACKEND_DEFAULT]() + @classmethod + def register(cls, name, kclass): + cls._registry[name] = kclass + + @classmethod + def get_instance(cls, name): + if name not in cls._registry: + module_name = f"reflector.processors.audio_transcript_{name}" + importlib.import_module(module_name) + + # gather specific configuration for the processor + # search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy` + config = {} + name_upper = name.upper() + settings_prefix = "TRANSCRIPT_" + config_prefix = f"{settings_prefix}{name_upper}_" + for key, value in settings: + if key.startswith(config_prefix): + config_name = key[len(settings_prefix) :].lower() + config[config_name] = value + + return cls._registry[name](**config) + + def __init__(self, **kwargs): + self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND) super().__init__(**kwargs) def connect(self, processor: Processor): diff --git a/server/reflector/processors/audio_transcript_banana.py b/server/reflector/processors/audio_transcript_banana.py new file mode 100644 index 00000000..af8f647d --- /dev/null +++ b/server/reflector/processors/audio_transcript_banana.py @@ -0,0 +1,85 @@ +""" +Implementation using the GPU service from banana. + +API will be a POST request to TRANSCRIPT_URL: + +```json +{ + "audio_url": "https://...", + "audio_ext": "wav", + "timestamp": 123.456 + "language": "en" +} +``` + +""" + +from reflector.processors.audio_transcript import AudioTranscriptProcessor +from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor +from reflector.processors.types import AudioFile, Transcript, Word +from reflector.settings import settings +from reflector.storage import Storage +from reflector.utils.retry import retry +from pathlib import Path +import httpx + + +class AudioTranscriptBananaProcessor(AudioTranscriptProcessor): + def __init__(self, banana_api_key: str, banana_model_key: str): + super().__init__() + self.transcript_url = settings.TRANSCRIPT_URL + self.timeout = settings.TRANSCRIPT_TIMEOUT + self.storage = Storage.get_instance( + settings.TRANSCRIPT_STORAGE_BACKEND, "TRANSCRIPT_STORAGE_" + ) + self.headers = { + "X-Banana-API-Key": banana_api_key, + "X-Banana-Model-Key": banana_model_key, + } + + async def _transcript(self, data: AudioFile): + async with httpx.AsyncClient() as client: + print(f"Uploading audio {data.path.name} to S3") + url = await self._upload_file(data.path) + + print(f"Try to transcribe audio {data.path.name}") + request_data = { + "audio_url": url, + "audio_ext": data.path.suffix[1:], + "timestamp": float(round(data.timestamp, 2)), + } + response = await retry(client.post)( + self.transcript_url, + json=request_data, + headers=self.headers, + timeout=self.timeout, + ) + + print(f"Transcript response: {response.status_code} {response.content}") + response.raise_for_status() + result = response.json() + transcript = Transcript( + text=result["text"], + words=[ + Word(text=word["text"], start=word["start"], end=word["end"]) + for word in result["words"] + ], + ) + + # remove audio file from S3 + await self._delete_file(data.path) + + return transcript + + @retry + async def _upload_file(self, path: Path) -> str: + upload_result = await self.storage.put_file(path.name, open(path, "rb")) + return upload_result.url + + @retry + async def _delete_file(self, path: Path): + await self.storage.delete_file(path.name) + return True + + +AudioTranscriptAutoProcessor.register("banana", AudioTranscriptBananaProcessor) diff --git a/server/reflector/processors/audio_transcript_whisper.py b/server/reflector/processors/audio_transcript_whisper.py index 0b768543..972c636a 100644 --- a/server/reflector/processors/audio_transcript_whisper.py +++ b/server/reflector/processors/audio_transcript_whisper.py @@ -1,4 +1,5 @@ from reflector.processors.audio_transcript import AudioTranscriptProcessor +from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor from reflector.processors.types import AudioFile, Transcript, Word from faster_whisper import WhisperModel @@ -40,3 +41,6 @@ class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor): ) return transcript + + +AudioTranscriptAutoProcessor.register("whisper", AudioTranscriptWhisperProcessor) diff --git a/server/reflector/processors/transcript_topic_detector.py b/server/reflector/processors/transcript_topic_detector.py index b602d61e..cb199825 100644 --- a/server/reflector/processors/transcript_topic_detector.py +++ b/server/reflector/processors/transcript_topic_detector.py @@ -28,7 +28,7 @@ class TranscriptTopicDetectorProcessor(Processor): super().__init__(**kwargs) self.transcript = None self.min_transcript_length = min_transcript_length - self.llm = LLM.instance() + self.llm = LLM.get_instance() async def _push(self, data: Transcript): if self.transcript is None: diff --git a/server/reflector/settings.py b/server/reflector/settings.py index fe1243bf..3b042462 100644 --- a/server/reflector/settings.py +++ b/server/reflector/settings.py @@ -26,8 +26,29 @@ class Settings(BaseSettings): AUDIO_SAMPLING_WIDTH: int = 2 AUDIO_BUFFER_SIZE: int = 256 * 960 + # Audio Transcription + # backends: whisper, banana + TRANSCRIPT_BACKEND: str = "whisper" + TRANSCRIPT_URL: str | None = None + TRANSCRIPT_TIMEOUT: int = 90 + + # Audio transcription banana.dev configuration + TRANSCRIPT_BANANA_API_KEY: str | None = None + TRANSCRIPT_BANANA_MODEL_KEY: str | None = None + + # Audio transcription storage + TRANSCRIPT_STORAGE_BACKEND: str = "aws" + + # Storage configuration for AWS + TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket/chunks" + TRANSCRIPT_STORAGE_AWS_REGION: str = "us-east-1" + TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None + TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None + # LLM LLM_BACKEND: str = "oobagooda" + + # LLM common configuration LLM_URL: str | None = None LLM_HOST: str = "localhost" LLM_PORT: int = 7860 @@ -38,11 +59,9 @@ class Settings(BaseSettings): LLM_MAX_TOKENS: int = 1024 LLM_TEMPERATURE: float = 0.7 - # Storage - STORAGE_BACKEND: str = "aws" - STORAGE_AWS_ACCESS_KEY: str = "" - STORAGE_AWS_SECRET_KEY: str = "" - STORAGE_AWS_BUCKET: str = "" + # LLM Banana configuration + LLM_BANANA_API_KEY: str | None = None + LLM_BANANA_MODEL_KEY: str | None = None # Sentry SENTRY_DSN: str | None = None diff --git a/server/reflector/storage/__init__.py b/server/reflector/storage/__init__.py new file mode 100644 index 00000000..fd4c72f0 --- /dev/null +++ b/server/reflector/storage/__init__.py @@ -0,0 +1 @@ +from .base import Storage # noqa diff --git a/server/reflector/storage/base.py b/server/reflector/storage/base.py new file mode 100644 index 00000000..5cdafdbf --- /dev/null +++ b/server/reflector/storage/base.py @@ -0,0 +1,47 @@ +from pydantic import BaseModel +from reflector.settings import settings +import importlib + + +class FileResult(BaseModel): + filename: str + url: str + + +class Storage: + _registry = {} + CONFIG_SETTINGS = [] + + @classmethod + def register(cls, name, kclass): + cls._registry[name] = kclass + + @classmethod + def get_instance(cls, name, settings_prefix=""): + if name not in cls._registry: + module_name = f"reflector.storage.storage_{name}" + importlib.import_module(module_name) + + # gather specific configuration for the processor + # search `TRANSCRIPT_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy` + config = {} + name_upper = name.upper() + config_prefix = f"{settings_prefix}{name_upper}_" + for key, value in settings: + if key.startswith(config_prefix): + config_name = key[len(settings_prefix) :].lower() + config[config_name] = value + + return cls._registry[name](**config) + + async def put_file(self, filename: str, data: bytes) -> FileResult: + return await self._put_file(filename, data) + + async def _put_file(self, filename: str, data: bytes) -> FileResult: + raise NotImplementedError + + async def delete_file(self, filename: str): + return await self._delete_file(filename) + + async def _delete_file(self, filename: str): + raise NotImplementedError diff --git a/server/reflector/storage/storage_aws.py b/server/reflector/storage/storage_aws.py new file mode 100644 index 00000000..09a9c383 --- /dev/null +++ b/server/reflector/storage/storage_aws.py @@ -0,0 +1,67 @@ +import aioboto3 +from reflector.storage.base import Storage, FileResult +from reflector.logger import logger + + +class AwsStorage(Storage): + def __init__( + self, + aws_access_key_id: str, + aws_secret_access_key: str, + aws_bucket_name: str, + aws_region: str, + ): + if not aws_access_key_id: + raise ValueError("Storage `aws_storage` require `aws_access_key_id`") + if not aws_secret_access_key: + raise ValueError("Storage `aws_storage` require `aws_secret_access_key`") + if not aws_bucket_name: + raise ValueError("Storage `aws_storage` require `aws_bucket_name`") + if not aws_region: + raise ValueError("Storage `aws_storage` require `aws_region`") + + super().__init__() + self.aws_bucket_name = aws_bucket_name + self.aws_folder = "" + if "/" in aws_bucket_name: + self.aws_bucket_name, self.aws_folder = aws_bucket_name.split("/", 1) + self.session = aioboto3.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + region_name=aws_region, + ) + self.base_url = f"https://{aws_bucket_name}.s3.amazonaws.com/" + + async def _put_file(self, filename: str, data: bytes) -> FileResult: + bucket = self.aws_bucket_name + folder = self.aws_folder + logger.info(f"Uploading {filename} to S3 {bucket}/{folder}") + s3filename = f"{folder}/{filename}" if folder else filename + async with self.session.client("s3") as client: + await client.put_object( + Bucket=bucket, + Key=s3filename, + Body=data, + ) + + presigned_url = await client.generate_presigned_url( + "get_object", + Params={"Bucket": bucket, "Key": s3filename}, + ExpiresIn=3600, + ) + + return FileResult( + filename=filename, + url=presigned_url, + ) + + async def _delete_file(self, filename: str): + bucket = self.aws_bucket_name + folder = self.aws_folder + logger.info(f"Deleting {filename} from S3 {bucket}/{folder}") + s3filename = f"{folder}/{filename}" if folder else filename + async with self.session.client("s3") as client: + await client.delete_object(Bucket=bucket, Key=s3filename) + + +Storage.register("aws", AwsStorage) diff --git a/server/reflector/tools/process.py b/server/reflector/tools/process.py index 071907ea..85febaff 100644 --- a/server/reflector/tools/process.py +++ b/server/reflector/tools/process.py @@ -12,7 +12,7 @@ from reflector.processors import ( import asyncio -async def process_audio_file(filename, event_callback): +async def process_audio_file(filename, event_callback, only_transcript=False): async def on_transcript(data): await event_callback("transcript", data) @@ -22,15 +22,21 @@ async def process_audio_file(filename, event_callback): async def on_summary(data): await event_callback("summary", data) - # transcription output - pipeline = Pipeline( + # build pipeline for audio processing + processors = [ AudioChunkerProcessor(), AudioMergeProcessor(), AudioTranscriptAutoProcessor.as_threaded(), TranscriptLinerProcessor(callback=on_transcript), - TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic), - TranscriptFinalSummaryProcessor.as_threaded(callback=on_summary), - ) + ] + if not only_transcript: + processors += [ + TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic), + TranscriptFinalSummaryProcessor.as_threaded(callback=on_summary), + ] + + # transcription output + pipeline = Pipeline(*processors) pipeline.describe() # start processing audio @@ -52,6 +58,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("source", help="Source file (mp3, wav, mp4...)") + parser.add_argument("--only-transcript", "-t", action="store_true") args = parser.parse_args() async def event_callback(event, data): @@ -62,4 +69,8 @@ if __name__ == "__main__": elif event == "summary": print(f"Summary: {data}") - asyncio.run(process_audio_file(args.source, event_callback)) + asyncio.run( + process_audio_file( + args.source, event_callback, only_transcript=args.only_transcript + ) + ) diff --git a/server/reflector/utils/retry.py b/server/reflector/utils/retry.py new file mode 100644 index 00000000..0a270f37 --- /dev/null +++ b/server/reflector/utils/retry.py @@ -0,0 +1,29 @@ +from reflector.logger import logger +import asyncio + + +def retry(fn): + async def decorated(*args, **kwargs): + retry_max = kwargs.pop("retry_max", 5) + retry_delay = kwargs.pop("retry_delay", 2) + retry_ignore_exc_types = kwargs.pop("retry_ignore_exc_types", ()) + result = None + attempt = 0 + last_exception = None + for attempt in range(retry_max): + try: + result = await fn(*args, **kwargs) + if result: + return result + except retry_ignore_exc_types as e: + last_exception = e + logger.debug( + f"Retrying {fn} - in {retry_delay} seconds " + f"- attempt {attempt + 1}/{retry_max}" + ) + await asyncio.sleep(retry_delay) + if last_exception is not None: + raise type(last_exception) from last_exception + return result + + return decorated diff --git a/server/tests/test_processors_pipeline.py b/server/tests/test_processors_pipeline.py index ab836550..fa8bf31a 100644 --- a/server/tests/test_processors_pipeline.py +++ b/server/tests/test_processors_pipeline.py @@ -12,6 +12,7 @@ async def test_basic_process(event_loop): # use an LLM test backend settings.LLM_BACKEND = "test" + settings.TRANSCRIPT_BACKEND = "whisper" class LLMTest(LLM): async def _generate(self, prompt: str, **kwargs) -> str: