diff --git a/server/poetry.lock b/server/poetry.lock index 71206cba..77242591 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1851,6 +1851,24 @@ pytest = ">=7.0.0" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] +[[package]] +name = "pytest-httpx" +version = "0.23.1" +description = "Send responses to httpx." +optional = false +python-versions = ">=3.9" +files = [ + {file = "pytest_httpx-0.23.1-py3-none-any.whl", hash = "sha256:ba38a9e6c685d3cf6197551a79bf7e41f8bbc57a6d1de65b537f77e87f56ecd3"}, + {file = "pytest_httpx-0.23.1.tar.gz", hash = "sha256:cfed19eb8b13cbdf464bbb1c4ef88717d88d42334aa9ce516e56e46975c77f74"}, +] + +[package.dependencies] +httpx = "==0.24.*" +pytest = ">=6.0,<8.0" + +[package.extras] +testing = ["pytest-asyncio (==0.21.*)", "pytest-cov (==4.*)"] + [[package]] name = "python-dateutil" version = "2.8.2" @@ -2042,26 +2060,6 @@ files = [ {file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"}, ] -[[package]] -name = "stamina" -version = "23.1.0" -description = "Production-grade retries made easy." -optional = false -python-versions = ">=3.8" -files = [ - {file = "stamina-23.1.0-py3-none-any.whl", hash = "sha256:850de8c2c2469aabf42a4c02e7372eaa12c2eced78f2bfa34162b8676c2846e5"}, - {file = "stamina-23.1.0.tar.gz", hash = "sha256:b16ce3d52d658aa75db813fc6a6661b770abfea915f72cda48e325f2a7854786"}, -] - -[package.dependencies] -tenacity = "*" - -[package.extras] -dev = ["nox", "prometheus-client", "stamina[tests,typing]", "structlog", "tomli"] -docs = ["furo", "myst-parser", "prometheus-client", "sphinx", "sphinx-notfound-page", "structlog"] -tests = ["pytest", "pytest-asyncio"] -typing = ["mypy (>=1.4)"] - [[package]] name = "starlette" version = "0.27.0" @@ -2110,20 +2108,6 @@ files = [ [package.dependencies] mpmath = ">=0.19" -[[package]] -name = "tenacity" -version = "8.2.2" -description = "Retry code until it succeeds" -optional = false -python-versions = ">=3.6" -files = [ - {file = "tenacity-8.2.2-py3-none-any.whl", hash = "sha256:2f277afb21b851637e8f52e6a613ff08734c347dc19ade928e519d7d2d8569b0"}, - {file = "tenacity-8.2.2.tar.gz", hash = "sha256:43af037822bd0029025877f3b2d97cc4d7bb0c2991000a3d59d71517c5c969e0"}, -] - -[package.extras] -doc = ["reno", "sphinx", "tornado (>=4.5)"] - [[package]] name = "tokenizers" version = "0.13.3" @@ -2595,4 +2579,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "1a98a080ce035b381521426c9d6f9f80e8656258beab6cdff95ea90cf6c77e85" +content-hash = "b6097887e0343a553bec5519aec6ecf345796e27d4a0f0f4abf8cd51e56a24eb" diff --git a/server/pyproject.toml b/server/pyproject.toml index dc446796..bd10f796 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -30,13 +30,13 @@ black = "^23.7.0" [tool.poetry.group.client.dependencies] pyaudio = "^0.2.13" -stamina = "^23.1.0" [tool.poetry.group.tests.dependencies] pytest-aiohttp = "^1.0.4" pytest-asyncio = "^0.21.1" pytest = "^7.4.0" +pytest-httpx = "^0.23.1" [tool.poetry.group.aws.dependencies] diff --git a/server/reflector/llm/llm_banana.py b/server/reflector/llm/llm_banana.py index 0a0bfc93..d6a0fa07 100644 --- a/server/reflector/llm/llm_banana.py +++ b/server/reflector/llm/llm_banana.py @@ -20,6 +20,7 @@ class BananaLLM(LLM): headers=self.headers, json={"prompt": prompt}, timeout=self.timeout, + retry_timeout=300, # as per their sdk ) response.raise_for_status() text = response.json()["text"] diff --git a/server/reflector/llm/llm_openai.py b/server/reflector/llm/llm_openai.py index 03189afc..dd438704 100644 --- a/server/reflector/llm/llm_openai.py +++ b/server/reflector/llm/llm_openai.py @@ -21,7 +21,6 @@ class OpenAILLM(LLM): "Authorization": f"Bearer {self.openai_key}", } - async with httpx.AsyncClient(timeout=self.timeout) as client: response = await client.post( self.openai_url, diff --git a/server/reflector/utils/retry.py b/server/reflector/utils/retry.py index 0a270f37..9d483d3a 100644 --- a/server/reflector/utils/retry.py +++ b/server/reflector/utils/retry.py @@ -1,27 +1,82 @@ from reflector.logger import logger +from time import monotonic +from httpx import HTTPStatusError, Response +from random import random import asyncio +class RetryException(Exception): + pass + + +class RetryTimeoutException(RetryException): + pass + + +class RetryHTTPException(RetryException): + pass + + def retry(fn): async def decorated(*args, **kwargs): - retry_max = kwargs.pop("retry_max", 5) - retry_delay = kwargs.pop("retry_delay", 2) - retry_ignore_exc_types = kwargs.pop("retry_ignore_exc_types", ()) + retry_attempts = kwargs.pop("retry_attempts", None) + retry_timeout = kwargs.pop("retry_timeout", 60) + retry_backoff_interval = kwargs.pop("retry_backoff_interval", 0.1) + retry_jitter = kwargs.pop("retry_jitter", 0.1) + retry_backoff_max = kwargs.pop("retry_backoff_max", 3) + retry_httpx_status_stop = kwargs.pop( + "retry_httpx_status_stop", + ( + 401, # auth issue + 404, # not found + 413, # payload too large + 418, # teapot + ), + ) + retry_ignore_exc_types = kwargs.pop("retry_ignore_exc_types", (Exception,)) + result = None - attempt = 0 last_exception = None - for attempt in range(retry_max): + attempts = 0 + start = monotonic() + fn_name = fn.__name__ + + # goal: retry until timeout + while True: + if monotonic() - start > retry_timeout: + raise RetryTimeoutException() + + jitter = random() * retry_jitter + retry_backoff_interval = min( + retry_backoff_interval * 2 + jitter, retry_backoff_max + ) + try: result = await fn(*args, **kwargs) + if isinstance(result, Response): + result.raise_for_status() if result: return result + except HTTPStatusError as e: + status_code = e.response.status_code + logger.debug(f"HTTP status {status_code} - {e}") + if status_code in retry_httpx_status_stop: + message = f"HTTP status {status_code} is in retry_httpx_status_stop" + raise RetryHTTPException(message) from e except retry_ignore_exc_types as e: last_exception = e + logger.debug( - f"Retrying {fn} - in {retry_delay} seconds " - f"- attempt {attempt + 1}/{retry_max}" + f"Retrying {fn_name} - in {retry_backoff_interval:.1f}s " + f"({monotonic() - start:.1f}s / {retry_timeout:.1f}s)" ) - await asyncio.sleep(retry_delay) + attempts += 1 + + if retry_attempts is not None and attempts >= retry_attempts: + raise RetryException(f"Retry attempts exceeded: {retry_attempts}") + + await asyncio.sleep(retry_backoff_interval) + if last_exception is not None: raise type(last_exception) from last_exception return result diff --git a/server/tests/test_retry_decorator.py b/server/tests/test_retry_decorator.py new file mode 100644 index 00000000..22729eac --- /dev/null +++ b/server/tests/test_retry_decorator.py @@ -0,0 +1,55 @@ +import pytest +import httpx +from reflector.utils.retry import ( + retry, + RetryTimeoutException, + RetryHTTPException, + RetryException, +) + + +@pytest.mark.asyncio +async def test_retry_httpx(httpx_mock): + # this code should be force a retry + httpx_mock.add_response(status_code=500) + async with httpx.AsyncClient() as client: + with pytest.raises(RetryTimeoutException): + await retry(client.get)("https://test_url", retry_timeout=0.1) + + # but if we add it in the retry_httpx_status_stop, it should not retry + async with httpx.AsyncClient() as client: + with pytest.raises(RetryHTTPException): + await retry(client.get)( + "https://test_url", retry_timeout=5, retry_httpx_status_stop=[500] + ) + + +@pytest.mark.asyncio +async def test_retry_normal(): + left = 3 + + async def retry_before_success(): + nonlocal left + if left > 0: + left -= 1 + raise Exception("test") + return True + + result = await retry(retry_before_success)() + assert result is True + assert left == 0 + + +@pytest.mark.asyncio +async def test_retry_max_attempts(): + left = 3 + + async def retry_before_success(): + nonlocal left + if left > 0: + left -= 1 + raise Exception("test") + return True + + with pytest.raises(RetryException): + await retry(retry_before_success)(retry_attempts=2)