diff --git a/mokelumne/dags/gen_llm_image_descriptions.py b/mokelumne/dags/gen_llm_image_descriptions.py index 8a88aff..f0b9549 100644 --- a/mokelumne/dags/gen_llm_image_descriptions.py +++ b/mokelumne/dags/gen_llm_image_descriptions.py @@ -12,6 +12,7 @@ from pathlib import Path from shutil import copyfile from typing import List +import requests from airflow.exceptions import AirflowFailException from airflow.providers.smtp.operators.smtp import EmailOperator from airflow.sdk import dag, task, task_group, Param, get_current_context @@ -27,6 +28,7 @@ from mokelumne.plugins.static_files.helpers import static_files_run_dir from mokelumne.util.storage import run_dir from mokelumne.util.tind_csv_writer import TindCsvWriter, is_single_image_record +from tind_client.errors import TooManyRequestsError logger = logging.getLogger(__name__) @@ -110,6 +112,12 @@ description_md="The maximum height for the fetched image. Must be less than 8000px." ), }, + default_args={ + "retries": 3, + "retry_delay": 3, + "retry_exponential_backoff": True, + "max_retry_delay": 600, # 10 minutes + }, tags=["batch-image", "csv", "generate-descriptions", "llm", "process",], ) def gen_llm_image_descriptions(): @@ -217,6 +225,13 @@ def fetch_image_to_record_directory(run_id: str, fetcher: ImageFetcher, ) path = str(fetcher.fetch_one_image_for_record(tind_id, run_id)) + except TooManyRequestsError as ex: + ti = context["ti"] + if ti.try_number <= ti.max_tries: + logger.warning("TIND API returned 429; marking record for retry") + raise + logger.warning("TIND API returned 429 on final attempt; marking record as failed") + return RunStatus(tind_id=tind_id, status="failed", description="TIND API too busy, try again later", path="") except Exception as ex: # pylint: disable=broad-exception-caught logger.warning("Fetcher encountered exception", exc_info=ex) return RunStatus(tind_id=tind_id, status="failed", description=str(ex), path="") @@ -383,7 +398,7 @@ def collate_csvs(output_dir_str: str) -> str: return timestamp @task - def generate_summary(output_dir_str: str, timestamp: str) -> str: + def generate_summary(output_dir_str: str, timestamp: str) -> str: # pylint: disable=too-many-locals """Generate a summary of the files in the collated path.""" def count_success_fail_of_csv(csv_file: Path, success: str) -> tuple[int, int, int]: """Count the success and failure rows for the given CSV.""" diff --git a/pyproject.toml b/pyproject.toml index 1301cbe..8722472 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ skip-checking-raises = true style = "sphinx" [tool.pylint.messages_control] -disable = ["expression-not-assigned"] +disable = ["expression-not-assigned", "pointless-statement", "too-many-statements"] [tool.pytest] minversion = "9.0" diff --git a/requirements.txt b/requirements.txt index 9995863..199bee6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -591,6 +591,7 @@ fastapi==0.135.3 \ --hash=sha256:9b0f590c813acd13d0ab43dd8494138eb58e484bfac405db1f3187cfc5810d98 \ --hash=sha256:bd6d7caf1a2bdd8d676843cdcd2287729572a1ef524fc4d65c17ae002a1be654 # via + # mokelumne (pyproject.toml) # apache-airflow-core # cadwyn fastapi-cli==0.0.24 \ @@ -850,6 +851,7 @@ httpx==0.28.1 \ --hash=sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc \ --hash=sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad # via + # mokelumne (pyproject.toml) # apache-airflow-core # apache-airflow-task-sdk # fastapi @@ -1998,9 +2000,9 @@ python-slugify==8.0.4 \ # via # apache-airflow-core # pytest-playwright -python-tind-client==0.2.3 \ - --hash=sha256:8d5286b1af0df3a87ffc4c80a7199663fb02983dbe2b576601ae76cf50e0e915 \ - --hash=sha256:dfeb0590267c6af5f60832eedb4a7fa4794844cc44bb75066febe9d9698599cc +python-tind-client==0.2.4 \ + --hash=sha256:52b4a44c33963f682f66d107940766a76c8cd4e3c829d65dd609ccf5f6641b5f \ + --hash=sha256:c9699c47223491bcb492d457542c9abfa8cabd1a585fffd23b2e3b55639d8166 # via mokelumne (pyproject.toml) pytz==2026.1.post1 \ --hash=sha256:3378dde6a0c3d26719182142c56e60c7f9af7e968076f31aae569d72a0358ee1 \ diff --git a/test/tests/test_gen_llm_image_descriptions.py b/test/tests/test_gen_llm_image_descriptions.py new file mode 100644 index 0000000..194355d --- /dev/null +++ b/test/tests/test_gen_llm_image_descriptions.py @@ -0,0 +1,53 @@ +"""Tests for gen_llm_image_descriptions DAG.""" +# pylint: disable=redefined-outer-name + +from unittest.mock import patch, MagicMock +import pytest + +from airflow.dag_processing.dagbag import DagBag +from pathlib import Path +from tind_client.errors import TooManyRequestsError + +dag_dir = Path(__file__).resolve().parent.parent.parent / "mokelumne/dags" + + +@pytest.fixture(scope="module") +def fetch_fn(): + """Fixture to get the fetch function from the DAG.""" + dagbag = DagBag(dag_folder=dag_dir.resolve(), include_examples=False) + dag = dagbag.get_dag("gen_llm_image_descriptions") + return dag.get_task("fetch_images.fetch_image_to_record_directory").python_callable + + +def _mock_context(try_number: int, max_tries: int) -> dict: + mock_ti = MagicMock() + mock_ti.try_number = try_number + mock_ti.max_tries = max_tries + return {"params": {"max_width": 8000, "max_height": 8000}, "run_id": "test", "ti": mock_ti} + + +def test_429_causes_task_retry(fetch_fn): + """If retries remain, a TindClient's TooManyRequestsError (429) triggers a retry.""" + mock_fetcher = MagicMock() + mock_fetcher.get_metadata_for_record.side_effect = TooManyRequestsError() + + with patch( + f"{fetch_fn.__module__}.get_current_context", + return_value=_mock_context(try_number=1, max_tries=3), + ): + with pytest.raises(TooManyRequestsError): + fetch_fn("test_run", mock_fetcher, "12345") + + +def test_429_on_final_attempt_returns_failed_status(fetch_fn): + """If last retry gets a TooManyRequestsError, the task returns a failed status.""" + mock_fetcher = MagicMock() + mock_fetcher.get_metadata_for_record.side_effect = TooManyRequestsError() + + with patch( + f"{fetch_fn.__module__}.get_current_context", + return_value=_mock_context(try_number=4, max_tries=3), + ): + result = fetch_fn("test_run", mock_fetcher, "12345") + assert result.tind_id == "12345" + assert result.status == "failed" diff --git a/test/unit/test_image_fetcher.py b/test/unit/test_image_fetcher.py index b871841..ec5bb70 100644 --- a/test/unit/test_image_fetcher.py +++ b/test/unit/test_image_fetcher.py @@ -4,6 +4,8 @@ from typing import Any from unittest.mock import Mock, call +from tind_client.errors import TooManyRequestsError + from mokelumne.util.image_fetcher import ImageFetcher, base64_size @@ -68,6 +70,13 @@ def download_image_from_record_sized(self, _record_id: str, _run_id: str, _width return str(FIXTURE_PATH / 'test3_scaled.jpg') +class MockTindHookWith429(MockTindHook): + """A mock TindHook that simulates a 429 Too Many Requests response.""" + def get_file_metadata(self, _record_id: str) -> list[dict[str, Any]]: + """Raise a Too Many Requests error.""" + raise TooManyRequestsError() + + def fetch_factory(tind_mock: MockTindHook, **kwargs) -> ImageFetcher: """Create an ImageFetcher with a mocked TIND fetcher client.