diff --git a/backend/apps/common/api/__init__.py b/backend/apps/common/api/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/apps/common/api/internal/__init__.py b/backend/apps/common/api/internal/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/apps/common/api/internal/dataloaders/__init__.py b/backend/apps/common/api/internal/dataloaders/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/apps/common/api/internal/dataloaders/utils.py b/backend/apps/common/api/internal/dataloaders/utils.py new file mode 100644 index 0000000000..361331c577 --- /dev/null +++ b/backend/apps/common/api/internal/dataloaders/utils.py @@ -0,0 +1,29 @@ +"""Shared utilities for GraphQL dataloaders.""" + +from collections import defaultdict + +from django.db.models import QuerySet + + +async def results_by_keys[K, V]( + queryset: QuerySet, keys: list[K], key_field: str, value_field: str +) -> list[list[V]]: + """Map a grouped-results dict back to an ordered list matching ``keys``. + + Args: + queryset: The queryset to iterate over. + keys: A list of keys to map the results to, in the desired order. + key_field: The name of the attribute on each item that contains the key. + value_field: The name of the attribute on each item that contains the value. + + Returns: + A list of result-lists, one per key, in the same order as ``keys``. + + """ + mapping: dict[K, list[V]] = defaultdict(list) + async for item in queryset: + key = getattr(item, key_field) + value = getattr(item, value_field) + mapping[key].append(value) + + return [mapping.get(key, []) for key in keys] diff --git a/backend/apps/github/api/internal/dataloaders/__init__.py b/backend/apps/github/api/internal/dataloaders/__init__.py new file mode 100644 index 0000000000..cf577cdd7b --- /dev/null +++ b/backend/apps/github/api/internal/dataloaders/__init__.py @@ -0,0 +1,12 @@ +from strawberry.dataloader import DataLoader + +from apps.github.api.internal.dataloaders.interested_users import make_interested_users_loader + +INTERESTED_USERS_LOADER = "interested_users_loader" + + +def make_github_dataloaders() -> dict[str, DataLoader]: + """Return a dict of dataloader instances for GitHub API resolvers.""" + return { + INTERESTED_USERS_LOADER: make_interested_users_loader(), + } diff --git a/backend/apps/github/api/internal/dataloaders/interested_users.py b/backend/apps/github/api/internal/dataloaders/interested_users.py new file mode 100644 index 0000000000..b4579a0996 --- /dev/null +++ b/backend/apps/github/api/internal/dataloaders/interested_users.py @@ -0,0 +1,22 @@ +"""DataLoader for interested users per issue.""" + +from strawberry.dataloader import DataLoader + +from apps.common.api.internal.dataloaders.utils import results_by_keys +from apps.github.models.user import User +from apps.mentorship.models.issue_user_interest import IssueUserInterest + + +async def load_interested_users(issue_ids: list[int]) -> list[list[User]]: + """Batch-load interested users for the given issue IDs in a single query.""" + interests = ( + IssueUserInterest.objects.select_related("user__owasp_profile") + .filter(issue_id__in=issue_ids) + .order_by("user__login") + ) + return await results_by_keys(interests, issue_ids, key_field="issue_id", value_field="user") + + +def make_interested_users_loader() -> DataLoader: + """Return a per-request DataLoader instance.""" + return DataLoader(load_fn=load_interested_users) diff --git a/backend/apps/github/api/internal/nodes/issue.py b/backend/apps/github/api/internal/nodes/issue.py index c603887d5f..c7521b7f5e 100644 --- a/backend/apps/github/api/internal/nodes/issue.py +++ b/backend/apps/github/api/internal/nodes/issue.py @@ -8,11 +8,11 @@ from strawberry.types import Info from apps.common.utils import normalize_limit +from apps.github.api.internal.dataloaders import INTERESTED_USERS_LOADER from apps.github.api.internal.nodes.pull_request import PullRequestNode from apps.github.api.internal.nodes.user import UserNode from apps.github.models.issue import Issue from apps.github.models.pull_request import PullRequest -from apps.mentorship.models.issue_user_interest import IssueUserInterest from apps.mentorship.models.task import Task MERGED_PULL_REQUESTS_PREFETCH = Prefetch( @@ -79,20 +79,10 @@ def is_merged(self, root: Issue) -> bool: """Return True if this issue has at least one merged pull request.""" return bool(getattr(root, "merged_pull_requests", None)) - @strawberry_django.field( - prefetch_related=[ - Prefetch( - "participant_interests", - queryset=IssueUserInterest.objects.select_related("user__owasp_profile").order_by( - "user__login" - ), - to_attr="interests_users", - ) - ] - ) - def interested_users(self, root: Issue) -> list[UserNode]: + @strawberry_django.field + async def interested_users(self, root: Issue, info: Info) -> list[UserNode]: """Return all users who have expressed interest in this issue.""" - return [interest.user for interest in getattr(root, "interests_users", [])] + return await info.context.github_dataloaders[INTERESTED_USERS_LOADER].load(root.pk) @strawberry.field def task_deadline(self, root: Issue, info: Info) -> datetime | None: diff --git a/backend/apps/github/api/internal/queries/organization.py b/backend/apps/github/api/internal/queries/organization.py index 996b287b62..e43917157a 100644 --- a/backend/apps/github/api/internal/queries/organization.py +++ b/backend/apps/github/api/internal/queries/organization.py @@ -1,6 +1,7 @@ """GitHub organization GraphQL queries.""" import strawberry +import strawberry_django from apps.github.api.internal.nodes.organization import OrganizationNode from apps.github.models.organization import Organization @@ -10,8 +11,8 @@ class OrganizationQuery: """Organization queries.""" - @strawberry.field - def organization( + @strawberry_django.field + async def organization( self, *, login: str, @@ -26,6 +27,6 @@ def organization( """ try: - return Organization.objects.get(is_owasp_related_organization=True, login=login) + return await Organization.objects.aget(is_owasp_related_organization=True, login=login) except Organization.DoesNotExist: return None diff --git a/backend/apps/mentorship/api/internal/queries/module.py b/backend/apps/mentorship/api/internal/queries/module.py index d9d3924fa1..e1505d5987 100644 --- a/backend/apps/mentorship/api/internal/queries/module.py +++ b/backend/apps/mentorship/api/internal/queries/module.py @@ -3,6 +3,8 @@ import logging import strawberry +import strawberry_django +from asgiref.sync import sync_to_async from apps.mentorship.api.internal.nodes.module import ModuleNode from apps.mentorship.models import Module, Program @@ -14,17 +16,19 @@ class ModuleQuery: """Module queries.""" - @strawberry.field - def get_program_modules(self, info: strawberry.Info, program_key: str) -> list[ModuleNode]: + @strawberry_django.field + async def get_program_modules( + self, info: strawberry.Info, program_key: str + ) -> list[ModuleNode]: """Get all modules by program Key. Returns an empty list if program is not found.""" try: - program = Program.objects.get(key=program_key) + program = await Program.objects.aget(key=program_key) except Program.DoesNotExist: return [] - if program.status != Program.ProgramStatus.PUBLISHED and not program.user_has_access( - info.context.request.user - ): + if program.status != Program.ProgramStatus.PUBLISHED and not await sync_to_async( + program.user_has_access + )(info.context.request.user): return [] return ( diff --git a/backend/apps/mentorship/api/internal/queries/program.py b/backend/apps/mentorship/api/internal/queries/program.py index a7e47b0c6e..f7f73fd7ac 100644 --- a/backend/apps/mentorship/api/internal/queries/program.py +++ b/backend/apps/mentorship/api/internal/queries/program.py @@ -3,6 +3,7 @@ import logging import strawberry +import strawberry_django from django.db.models import Q from apps.common.utils import normalize_limit @@ -20,13 +21,13 @@ class ProgramQuery: """Program queries.""" - @strawberry.field - def get_program(self, info: strawberry.Info, program_key: str) -> ProgramNode | None: + @strawberry_django.field + async def get_program(self, info: strawberry.Info, program_key: str) -> ProgramNode | None: """Get a program by Key.""" try: - program = Program.objects.prefetch_related( + program = await Program.objects.prefetch_related( "admins__github_user", "admins__nest_user" - ).get(key=program_key) + ).aget(key=program_key) except Program.DoesNotExist: msg = f"Program with key '{program_key}' not found." logger.warning(msg, exc_info=True) diff --git a/backend/apps/owasp/api/internal/queries/stats.py b/backend/apps/owasp/api/internal/queries/stats.py index 21cdc9a617..7b1bab6657 100644 --- a/backend/apps/owasp/api/internal/queries/stats.py +++ b/backend/apps/owasp/api/internal/queries/stats.py @@ -1,12 +1,14 @@ """OWASP stats GraphQL queries.""" import strawberry +import strawberry_django from apps.common.utils import round_down from apps.github.models.user import User from apps.owasp.api.internal.nodes.stats import StatsNode from apps.owasp.models.chapter import Chapter from apps.owasp.models.project import Project +from apps.slack.constants import OWASP_WORKSPACE_ID from apps.slack.models.workspace import Workspace @@ -14,25 +16,22 @@ class StatsQuery: """Stats queries.""" - @strawberry.field - def stats_overview(self) -> StatsNode: + @strawberry_django.field + async def stats_overview(self) -> StatsNode: """Resolve stats overview.""" - active_projects_stats = Project.active_projects_count() - active_chapters_stats = Chapter.active_chapters_count() - contributors_stats = User.objects.count() - countries_stats = ( + active_projects_stats = await Project.active_projects.acount() + active_chapters_stats = await Chapter.active_chapters.acount() + contributors_stats = await User.objects.acount() + countries_stats = await ( Chapter.objects.filter(country__isnull=False) .exclude(country="") .values("country") .distinct() - .count() + .acount() ) - slack_workspace_stats = ( - workspace.total_members_count - if (workspace := Workspace.get_default_workspace()) - else 0 - ) + workspace = await Workspace.objects.filter(slack_workspace_id=OWASP_WORKSPACE_ID).afirst() + slack_workspace_stats = workspace.total_members_count if workspace else 0 return StatsNode( active_chapters_stats=round_down(active_chapters_stats, 10), diff --git a/backend/poetry.lock b/backend/poetry.lock index 7d1a49ea8b..281a1d34cf 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.3.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -1874,7 +1874,7 @@ files = [ [package.dependencies] attrs = ">=22.2.0" -jsonschema-specifications = ">=2023.3.6" +jsonschema-specifications = ">=2023.03.6" referencing = ">=0.28.4" rpds-py = ">=0.25.0" @@ -3544,6 +3544,25 @@ pygments = ">=2.7.2" [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "1.3.0" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.10" +groups = ["test"] +files = [ + {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, + {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, +] + +[package.dependencies] +pytest = ">=8.2,<10" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "pytest-cov" version = "7.1.0" @@ -4097,10 +4116,10 @@ files = [ ] [package.dependencies] -botocore = ">=1.37.4,<2.0a0" +botocore = ">=1.37.4,<2.0a.0" [package.extras] -crt = ["botocore[crt] (>=1.37.4,<2.0a0)"] +crt = ["botocore[crt] (>=1.37.4,<2.0a.0)"] [[package]] name = "schemathesis" @@ -5134,9 +5153,9 @@ files = [ ] [package.extras] -cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and python_version < \"3.14\"", "cffi (>=2.0.0b0) ; platform_python_implementation != \"PyPy\" and python_version >= \"3.14\""] +cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and python_version < \"3.14\"", "cffi (>=2.0.0b) ; platform_python_implementation != \"PyPy\" and python_version >= \"3.14\""] [metadata] lock-version = "2.1" python-versions = "^3.13" -content-hash = "c357cbbdb451c13ac10249e313f95a239dd1d41b5cb15201d056223b4c302ce7" +content-hash = "3f712dc0290347b18e75fc6cd15aa2faec11113fd7a3087c7797623ea92fdce3" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 07e7d912ed..68e4ab90e9 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -47,6 +47,7 @@ dependencies.thefuzz = "^0.22.1" dependencies.pyparsing = "^3.2.3" group.fuzz.dependencies.schemathesis = "^4.10.2" group.test.dependencies.pytest = "^9.0.1" +group.test.dependencies.pytest-asyncio = "^1.3.0" group.test.dependencies.pytest-cov = "^7.0" group.test.dependencies.pytest-django = "^4.5" group.test.dependencies.pytest-mock = "^3.0" @@ -156,6 +157,7 @@ ini_options.filterwarnings = [ "ignore::DeprecationWarning:xdist", "ignore::pydantic.warnings.PydanticDeprecatedSince20", ] +ini_options.asyncio_mode = "auto" ini_options.log_level = "INFO" [tool.coverage] diff --git a/backend/settings/graphql.py b/backend/settings/graphql.py index 4f7fd8691a..12fc36c38b 100644 --- a/backend/settings/graphql.py +++ b/backend/settings/graphql.py @@ -2,6 +2,8 @@ import strawberry from django.conf import settings +from django.http import HttpRequest, HttpResponse +from strawberry.django.views import AsyncGraphQLView from strawberry.extensions import DisableIntrospection, QueryDepthLimiter from strawberry_django.optimizer import DjangoOptimizerExtension @@ -19,6 +21,7 @@ ) from apps.nest.api.internal.mutations import NestMutations from apps.owasp.api.internal.queries import OwaspQuery +from settings.graphql_context import NestGraphQLContext @strawberry.type @@ -43,6 +46,16 @@ class Query( """Schema queries.""" +class NestGraphQLView(AsyncGraphQLView[NestGraphQLContext, None]): + """Nest GraphQL view.""" + + async def get_context( + self, request: HttpRequest, response: HttpResponse + ) -> NestGraphQLContext: + """Return a NestGraphQLContext instance.""" + return NestGraphQLContext(request=request, response=response) + + extensions = [ QueryDepthLimiter(max_depth=5), DjangoOptimizerExtension(), diff --git a/backend/settings/graphql_context.py b/backend/settings/graphql_context.py new file mode 100644 index 0000000000..78a330a16b --- /dev/null +++ b/backend/settings/graphql_context.py @@ -0,0 +1,14 @@ +"""Custom GraphQL context for OWASP Nest.""" + +from strawberry.django.context import StrawberryDjangoContext + +from apps.github.api.internal.dataloaders import make_github_dataloaders + + +class NestGraphQLContext(StrawberryDjangoContext): + """Nest GraphQL context.""" + + def __init__(self, *args, **kwargs) -> None: + """Initialize the context with fresh dataloader instances.""" + super().__init__(*args, **kwargs) + self.github_dataloaders = make_github_dataloaders() diff --git a/backend/settings/urls.py b/backend/settings/urls.py index 67cfae2297..798872f0ad 100644 --- a/backend/settings/urls.py +++ b/backend/settings/urls.py @@ -9,7 +9,6 @@ from django.contrib import admin from django.urls import include, path from django.views.decorators.csrf import csrf_protect -from strawberry.django.views import GraphQLView from apps.api.rest.v0 import api as api_v0 from apps.core.api.internal.algolia import algolia_search @@ -17,7 +16,7 @@ from apps.core.api.internal.status import get_status from apps.owasp.api.internal.views.urls import urlpatterns as owasp_urls from apps.slack.apps import SlackConfig -from settings.graphql import schema +from settings.graphql import NestGraphQLView, schema urlpatterns = [ path("csrf/", get_csrf_token), @@ -25,7 +24,7 @@ path( "graphql/", csrf_protect( - GraphQLView.as_view( + NestGraphQLView.as_view( schema=schema, graphql_ide="graphiql" if settings.DEBUG else None, ) diff --git a/backend/tests/unit/apps/common/api/__init__.py b/backend/tests/unit/apps/common/api/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/tests/unit/apps/common/api/internal/__init__.py b/backend/tests/unit/apps/common/api/internal/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/tests/unit/apps/common/api/internal/dataloaders/__init__.py b/backend/tests/unit/apps/common/api/internal/dataloaders/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/tests/unit/apps/common/api/internal/dataloaders/utils_test.py b/backend/tests/unit/apps/common/api/internal/dataloaders/utils_test.py new file mode 100644 index 0000000000..b3d37a19ba --- /dev/null +++ b/backend/tests/unit/apps/common/api/internal/dataloaders/utils_test.py @@ -0,0 +1,141 @@ +"""Tests for shared dataloader utility functions.""" + +from types import SimpleNamespace +from typing import cast + +import pytest +from django.db.models import QuerySet + +from apps.common.api.internal.dataloaders.utils import results_by_keys + + +async def _async_gen(items): + """Yield items one-by-one as an async generator.""" + for item in items: + yield item + + +def make_qs(items) -> QuerySet: + """Wrap a list in an async-iterable that satisfies the QuerySet type annotation.""" + return cast("QuerySet", _async_gen(items)) + + +def make_item(**kwargs): + """Return a simple namespace object with the given attributes.""" + return SimpleNamespace(**kwargs) + + +class TestResultsByKeys: + """Tests for results_by_keys.""" + + @pytest.mark.asyncio + async def test_basic_mapping(self): + """Each key maps to exactly one value.""" + items = [ + make_item(org_id=1, repo="repo-a"), + make_item(org_id=2, repo="repo-b"), + ] + result = await results_by_keys(make_qs(items), [1, 2], "org_id", "repo") + assert result == [["repo-a"], ["repo-b"]] + + @pytest.mark.asyncio + async def test_empty_queryset(self): + """Empty queryset returns an empty list for every key.""" + result = await results_by_keys(make_qs([]), [1, 2, 3], "org_id", "repo") + assert result == [[], [], []] + + @pytest.mark.asyncio + async def test_empty_keys(self): + """Empty keys list returns an empty list regardless of queryset content.""" + items = [make_item(org_id=1, repo="repo-a")] + result = await results_by_keys(make_qs(items), [], "org_id", "repo") + assert result == [] + + @pytest.mark.asyncio + async def test_key_absent_from_results(self): + """A key with no matching queryset items gets an empty list.""" + items = [make_item(org_id=1, repo="repo-a")] + result = await results_by_keys(make_qs(items), [1, 2], "org_id", "repo") + assert result == [["repo-a"], []] + + @pytest.mark.asyncio + async def test_multiple_values_per_key(self): + """Multiple items sharing the same key are collected into one list.""" + items = [ + make_item(org_id=1, repo="repo-a"), + make_item(org_id=1, repo="repo-b"), + make_item(org_id=2, repo="repo-c"), + ] + result = await results_by_keys(make_qs(items), [1, 2], "org_id", "repo") + assert result == [["repo-a", "repo-b"], ["repo-c"]] + + @pytest.mark.asyncio + async def test_order_matches_keys_not_queryset(self): + """The output order follows ``keys``, not the queryset iteration order.""" + items = [ + make_item(org_id=3, repo="repo-c"), + make_item(org_id=1, repo="repo-a"), + make_item(org_id=2, repo="repo-b"), + ] + result = await results_by_keys(make_qs(items), [1, 2, 3], "org_id", "repo") + assert result == [["repo-a"], ["repo-b"], ["repo-c"]] + + @pytest.mark.asyncio + async def test_items_not_in_keys_are_ignored(self): + """Items whose key is not present in ``keys`` are silently discarded.""" + items = [ + make_item(org_id=1, repo="repo-a"), + make_item(org_id=99, repo="orphan"), # key 99 is not requested + ] + result = await results_by_keys(make_qs(items), [1], "org_id", "repo") + assert result == [["repo-a"]] + + @pytest.mark.asyncio + async def test_arbitrary_field_names(self): + """key_field and value_field are applied via getattr, so any attribute names work.""" + items = [ + make_item(chapter_id="us", city="New York"), + make_item(chapter_id="us", city="San Francisco"), + make_item(chapter_id="uk", city="London"), + ] + result = await results_by_keys(make_qs(items), ["uk", "us"], "chapter_id", "city") + assert result == [["London"], ["New York", "San Francisco"]] + + @pytest.mark.asyncio + async def test_duplicate_keys_in_keys_list(self): + """A key appearing multiple times in ``keys`` produces one entry per occurrence.""" + items = [make_item(org_id=1, repo="repo-a")] + result = await results_by_keys(make_qs(items), [1, 1], "org_id", "repo") + assert result == [["repo-a"], ["repo-a"]] + + @pytest.mark.parametrize( + ("items", "keys", "key_field", "value_field", "expected"), + [ + ( + [make_item(pk=10, name="alpha")], + [10], + "pk", + "name", + [["alpha"]], + ), + ( + [], + [10, 20], + "pk", + "name", + [[], []], + ), + ( + [make_item(pk=1, name="x"), make_item(pk=1, name="y"), make_item(pk=2, name="z")], + [2, 1], + "pk", + "name", + [["z"], ["x", "y"]], + ), + ], + ) + @pytest.mark.asyncio + async def test_parametrized_scenarios(self, items, keys, key_field, value_field, expected): + """Parametrized spot-checks covering single value, empty, and multi-value cases.""" + result = await results_by_keys(make_qs(items), keys, key_field, value_field) + assert result == expected diff --git a/backend/tests/unit/apps/github/api/__init__.py b/backend/tests/unit/apps/github/api/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/tests/unit/apps/github/api/internal/__init__.py b/backend/tests/unit/apps/github/api/internal/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/tests/unit/apps/github/api/internal/dataloaders/__init__.py b/backend/tests/unit/apps/github/api/internal/dataloaders/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/tests/unit/apps/github/api/internal/dataloaders/interested_users_test.py b/backend/tests/unit/apps/github/api/internal/dataloaders/interested_users_test.py new file mode 100644 index 0000000000..5caf4d8dc1 --- /dev/null +++ b/backend/tests/unit/apps/github/api/internal/dataloaders/interested_users_test.py @@ -0,0 +1,154 @@ +"""Tests for the interested_users dataloader.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from strawberry.dataloader import DataLoader + +from apps.github.api.internal.dataloaders.interested_users import ( + load_interested_users, + make_interested_users_loader, +) + + +class TestLoadInterestedUsers: + """Tests for load_interested_users.""" + + @patch( + "apps.github.api.internal.dataloaders.interested_users.results_by_keys", + new_callable=AsyncMock, + ) + @patch("apps.github.api.internal.dataloaders.interested_users.IssueUserInterest") + @pytest.mark.asyncio + async def test_builds_queryset_with_correct_chain(self, mock_interested, mock_results_by_keys): + """Queryset is built with select_related, filter, and order_by in the right order.""" + issue_ids = [1, 2, 3] + mock_queryset = MagicMock() + mock_interested_filter = mock_interested.objects.select_related.return_value.filter + mock_interested_filter.return_value.order_by.return_value = mock_queryset + mock_results_by_keys.return_value = [[], [], []] + + await load_interested_users(issue_ids) + + mock_interested.objects.select_related.assert_called_once_with("user__owasp_profile") + mock_filter = mock_interested.objects.select_related.return_value.filter + mock_filter.assert_called_once_with(issue_id__in=issue_ids) + mock_filter.return_value.order_by.assert_called_once_with("user__login") + + @patch( + "apps.github.api.internal.dataloaders.interested_users.results_by_keys", + new_callable=AsyncMock, + ) + @patch("apps.github.api.internal.dataloaders.interested_users.IssueUserInterest") + @pytest.mark.asyncio + async def test_delegates_to_results_by_keys_correct_args( + self, mock_interested, mock_results_by_keys + ): + """results_by_keys receives the queryset, issue_ids, and correct field names.""" + issue_ids = [10, 20] + mock_queryset = MagicMock() + mock_interested_filter = mock_interested.objects.select_related.return_value.filter + mock_interested_filter.return_value.order_by.return_value = mock_queryset + mock_results_by_keys.return_value = [[], []] + + await load_interested_users(issue_ids) + + mock_results_by_keys.assert_called_once_with( + mock_queryset, issue_ids, key_field="issue_id", value_field="user" + ) + + @patch( + "apps.github.api.internal.dataloaders.interested_users.results_by_keys", + new_callable=AsyncMock, + ) + @patch("apps.github.api.internal.dataloaders.interested_users.IssueUserInterest") + @pytest.mark.asyncio + async def test_returns_result_from_results_by_keys( + self, mock_interested, mock_results_by_keys + ): + """The return value is exactly what results_by_keys resolves to.""" + mock_user_a = MagicMock() + mock_user_b = MagicMock() + expected = [[mock_user_a, mock_user_b], [], [mock_user_a]] + mock_results_by_keys.return_value = expected + + result = await load_interested_users([1, 2, 3]) + + assert result is expected + + @patch( + "apps.github.api.internal.dataloaders.interested_users.results_by_keys", + new_callable=AsyncMock, + ) + @patch("apps.github.api.internal.dataloaders.interested_users.IssueUserInterest") + @pytest.mark.asyncio + async def test_empty_issue_ids(self, mock_interested, mock_results_by_keys): + """An empty issue_ids list results in an empty filter and empty return.""" + mock_interested_filter = mock_interested.objects.select_related.return_value.filter + mock_interested_filter.return_value.order_by.return_value = MagicMock() + mock_results_by_keys.return_value = [] + + result = await load_interested_users([]) + + mock_interested.objects.select_related.return_value.filter.assert_called_once_with( + issue_id__in=[] + ) + assert result == [] + + @patch( + "apps.github.api.internal.dataloaders.interested_users.results_by_keys", + new_callable=AsyncMock, + ) + @patch("apps.github.api.internal.dataloaders.interested_users.IssueUserInterest") + @pytest.mark.asyncio + async def test_single_issue_id(self, mock_interested, mock_results_by_keys): + """A single-element list is handled correctly end-to-end.""" + mock_user = MagicMock() + mock_interested_filter = mock_interested.objects.select_related.return_value.filter + mock_interested_filter.return_value.order_by.return_value = MagicMock() + mock_results_by_keys.return_value = [[mock_user]] + + result = await load_interested_users([42]) + + mock_interested.objects.select_related.return_value.filter.assert_called_once_with( + issue_id__in=[42] + ) + assert result == [[mock_user]] + + @patch( + "apps.github.api.internal.dataloaders.interested_users.results_by_keys", + new_callable=AsyncMock, + ) + @patch("apps.github.api.internal.dataloaders.interested_users.IssueUserInterest") + @pytest.mark.asyncio + async def test_preserves_issue_ids_order(self, mock_interested, mock_results_by_keys): + """The issue_ids list is forwarded to results_by_keys unchanged, preserving order.""" + issue_ids = [30, 10, 20] + mock_interested_filter = mock_interested.objects.select_related.return_value.filter + mock_interested_filter.return_value.order_by.return_value = MagicMock() + mock_results_by_keys.return_value = [[], [], []] + + await load_interested_users(issue_ids) + + _, positional_args, _ = mock_results_by_keys.mock_calls[0] + assert positional_args[1] is issue_ids + + +class TestMakeInterestedUsersLoader: + """Tests for make_interested_users_loader.""" + + def test_returns_dataloader_instance(self): + """Factory always returns a DataLoader.""" + loader = make_interested_users_loader() + assert isinstance(loader, DataLoader) + + def test_returns_new_instance_on_each_call(self): + """Each call produces a distinct DataLoader for per-request isolation.""" + loader1 = make_interested_users_loader() + loader2 = make_interested_users_loader() + assert loader1 is not loader2 + + def test_load_fn_is_load_interested_users(self): + """The DataLoader is wired to load_interested_users.""" + loader = make_interested_users_loader() + assert loader.load_fn is load_interested_users diff --git a/backend/tests/unit/apps/github/api/internal/nodes/issue_test.py b/backend/tests/unit/apps/github/api/internal/nodes/issue_test.py index fd1b004d52..52b021f287 100644 --- a/backend/tests/unit/apps/github/api/internal/nodes/issue_test.py +++ b/backend/tests/unit/apps/github/api/internal/nodes/issue_test.py @@ -1,8 +1,11 @@ """Test cases for IssueNode.""" from datetime import UTC, datetime -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch +import pytest + +from apps.github.api.internal.dataloaders import INTERESTED_USERS_LOADER from apps.github.api.internal.nodes.issue import IssueNode from tests.unit.apps.common.graphql_node_base_test import GraphQLNodeBaseTest @@ -150,24 +153,28 @@ def test_is_merged_false(self): result = field.base_resolver.wrapped_func(None, mock_issue) assert not result - def test_interested_users(self): - """Test interested_users field returns list of users from prefetched interests_users.""" + @pytest.mark.asyncio + async def test_interested_users(self): + """Test interested_users field returns users from the interested users dataloader.""" mock_issue = Mock() + mock_issue.pk = 123 mock_user1 = Mock() mock_user1.login = "user1" mock_user2 = Mock() mock_user2.login = "user2" - mock_interest1 = Mock() - mock_interest1.user = mock_user1 - mock_interest2 = Mock() - mock_interest2.user = mock_user2 + mock_loader = Mock() + mock_loader.load = AsyncMock(return_value=[mock_user1, mock_user2]) - mock_issue.interests_users = [mock_interest1, mock_interest2] + mock_info = Mock() + mock_info.context = Mock() + mock_info.context.github_dataloaders = {INTERESTED_USERS_LOADER: mock_loader} field = self._get_field_by_name("interested_users", IssueNode) - result = field.base_resolver.wrapped_func(None, mock_issue) + result = await field.base_resolver.wrapped_func(None, mock_issue, mock_info) + assert result == [mock_user1, mock_user2] + mock_loader.load.assert_awaited_once_with(mock_issue.pk) def test_task_deadline_with_bulk_load_mapping(self): """Test task_deadline field when mapping exists with issue deadline.""" diff --git a/backend/tests/unit/apps/github/api/internal/queries/organization_test.py b/backend/tests/unit/apps/github/api/internal/queries/organization_test.py index 0b439bff17..3db0b03fdc 100644 --- a/backend/tests/unit/apps/github/api/internal/queries/organization_test.py +++ b/backend/tests/unit/apps/github/api/internal/queries/organization_test.py @@ -1,6 +1,6 @@ """Test cases for OrganizationQuery.""" -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest @@ -19,32 +19,35 @@ def mock_organization(self): org.login = "owasp" return org - @patch("apps.github.models.organization.Organization.objects.get") - def test_organization_found(self, mock_get, mock_organization): + @patch("apps.github.models.organization.Organization.objects.aget", new_callable=AsyncMock) + @pytest.mark.asyncio + async def test_organization_found(self, mock_aget, mock_organization): """Test fetching organization when it exists.""" - mock_get.return_value = mock_organization + mock_aget.return_value = mock_organization - result = OrganizationQuery().organization(login="owasp") + result = await OrganizationQuery().organization(login="owasp") assert result == mock_organization - mock_get.assert_called_once_with(is_owasp_related_organization=True, login="owasp") + mock_aget.assert_called_once_with(is_owasp_related_organization=True, login="owasp") - @patch("apps.github.models.organization.Organization.objects.get") - def test_organization_not_found(self, mock_get): + @patch("apps.github.models.organization.Organization.objects.aget", new_callable=AsyncMock) + @pytest.mark.asyncio + async def test_organization_not_found(self, mock_aget): """Test fetching organization when it doesn't exist.""" - mock_get.side_effect = Organization.DoesNotExist() + mock_aget.side_effect = Organization.DoesNotExist() - result = OrganizationQuery().organization(login="nonexistent") + result = await OrganizationQuery().organization(login="nonexistent") assert result is None - mock_get.assert_called_once_with(is_owasp_related_organization=True, login="nonexistent") + mock_aget.assert_called_once_with(is_owasp_related_organization=True, login="nonexistent") - @patch("apps.github.models.organization.Organization.objects.get") - def test_organization_with_different_login(self, mock_get, mock_organization): + @patch("apps.github.models.organization.Organization.objects.aget", new_callable=AsyncMock) + @pytest.mark.asyncio + async def test_organization_with_different_login(self, mock_aget, mock_organization): """Test fetching organization with different login.""" - mock_get.return_value = mock_organization + mock_aget.return_value = mock_organization - result = OrganizationQuery().organization(login="test-org") + result = await OrganizationQuery().organization(login="test-org") assert result == mock_organization - mock_get.assert_called_once_with(is_owasp_related_organization=True, login="test-org") + mock_aget.assert_called_once_with(is_owasp_related_organization=True, login="test-org") diff --git a/backend/tests/unit/apps/mentorship/api/internal/queries/api_queries_module_test.py b/backend/tests/unit/apps/mentorship/api/internal/queries/api_queries_module_test.py index b482ee3f95..763a902334 100644 --- a/backend/tests/unit/apps/mentorship/api/internal/queries/api_queries_module_test.py +++ b/backend/tests/unit/apps/mentorship/api/internal/queries/api_queries_module_test.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest import strawberry @@ -45,10 +45,13 @@ class TestModuleQuery: """Tests for ModuleQuery.""" @patch("apps.mentorship.api.internal.queries.module.Module.objects.filter") - @patch("apps.mentorship.api.internal.queries.module.Program.objects.get") - def test_get_program_modules_success( + @patch( + "apps.mentorship.api.internal.queries.module.Program.objects.aget", new_callable=AsyncMock + ) + @pytest.mark.asyncio + async def test_get_program_modules_success( self, - mock_program_get: MagicMock, + mock_program_aget: AsyncMock, mock_module_filter: MagicMock, mock_info: MagicMock, api_module_queries, @@ -56,50 +59,58 @@ def test_get_program_modules_success( """Test successful retrieval of modules by program key.""" mock_program = MagicMock(spec=Program) mock_program.status = Program.ProgramStatus.PUBLISHED - mock_program_get.return_value = mock_program + mock_program_aget.return_value = mock_program mock_module = MagicMock(spec=Module) mock_module_filter_related = mock_module_filter.return_value.select_related.return_value mock_module_filter_related.prefetch_related.return_value.order_by.return_value = [ mock_module ] - result = api_module_queries.get_program_modules(info=mock_info, program_key="program1") + result = await api_module_queries.get_program_modules( + info=mock_info, program_key="program1" + ) assert result == [mock_module] - mock_program_get.assert_called_once_with(key="program1") + mock_program_aget.assert_called_once_with(key="program1") mock_module_filter.assert_called_once_with(program=mock_program) - @patch("apps.mentorship.api.internal.queries.module.Program.objects.get") - def test_get_program_modules_empty( + @patch( + "apps.mentorship.api.internal.queries.module.Program.objects.aget", new_callable=AsyncMock + ) + @pytest.mark.asyncio + async def test_get_program_modules_empty( self, - mock_program_get: MagicMock, + mock_program_aget: AsyncMock, mock_info: MagicMock, api_module_queries, ) -> None: """Test retrieval of modules returns empty list if program not found.""" - mock_program_get.side_effect = Program.DoesNotExist + mock_program_aget.side_effect = Program.DoesNotExist - result = api_module_queries.get_program_modules( + result = await api_module_queries.get_program_modules( info=mock_info, program_key="nonexistent_program" ) assert result == [] - mock_program_get.assert_called_once_with(key="nonexistent_program") + mock_program_aget.assert_called_once_with(key="nonexistent_program") - @patch("apps.mentorship.api.internal.queries.module.Program.objects.get") - def test_get_program_modules_hidden_for_draft_program( + @patch( + "apps.mentorship.api.internal.queries.module.Program.objects.aget", new_callable=AsyncMock + ) + @pytest.mark.asyncio + async def test_get_program_modules_hidden_for_draft_program( self, - mock_program_get: MagicMock, + mock_program_aget: AsyncMock, mock_anonymous_info: MagicMock, api_module_queries, ) -> None: """Test that modules of a draft program are hidden from anonymous users.""" mock_program = MagicMock(spec=Program) mock_program.status = Program.ProgramStatus.DRAFT - mock_program_get.return_value = mock_program + mock_program_aget.return_value = mock_program mock_program.user_has_access.return_value = False - result = api_module_queries.get_program_modules( + result = await api_module_queries.get_program_modules( info=mock_anonymous_info, program_key="draft-program" ) diff --git a/backend/tests/unit/apps/mentorship/api/internal/queries/api_queries_program_test.py b/backend/tests/unit/apps/mentorship/api/internal/queries/api_queries_program_test.py index 2a5a89e31e..a29bca0ca6 100644 --- a/backend/tests/unit/apps/mentorship/api/internal/queries/api_queries_program_test.py +++ b/backend/tests/unit/apps/mentorship/api/internal/queries/api_queries_program_test.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest import strawberry @@ -46,39 +46,44 @@ class TestGetProgram: """Tests for the get_program query.""" @patch("apps.mentorship.api.internal.queries.program.Program.objects.prefetch_related") - def test_get_program_success( + @pytest.mark.asyncio + async def test_get_program_success( self, mock_program_prefetch_related: MagicMock, mock_info: MagicMock, api_program_queries ) -> None: """Test successful retrieval of a published program by key.""" mock_program = MagicMock(spec=Program) mock_program.status = Program.ProgramStatus.PUBLISHED - mock_program_prefetch_related.return_value.get.return_value = mock_program + mock_program_prefetch_related.return_value.aget = AsyncMock(return_value=mock_program) - result = api_program_queries.get_program(info=mock_info, program_key="program1") + result = await api_program_queries.get_program(info=mock_info, program_key="program1") assert result == mock_program mock_program_prefetch_related.assert_called_once_with( "admins__github_user", "admins__nest_user" ) - mock_program_prefetch_related.return_value.get.assert_called_once_with(key="program1") + mock_program_prefetch_related.return_value.aget.assert_called_once_with(key="program1") @patch("apps.mentorship.api.internal.queries.program.Program.objects.prefetch_related") - def test_get_program_does_not_exist( + @pytest.mark.asyncio + async def test_get_program_does_not_exist( self, mock_program_prefetch_related: MagicMock, mock_info: MagicMock, api_program_queries ) -> None: """Test when the program does not exist.""" - mock_program_prefetch_related.return_value.get.side_effect = Program.DoesNotExist + mock_program_prefetch_related.return_value.aget = AsyncMock( + side_effect=Program.DoesNotExist + ) - result = api_program_queries.get_program(info=mock_info, program_key="nonexistent") + result = await api_program_queries.get_program(info=mock_info, program_key="nonexistent") assert result is None mock_program_prefetch_related.assert_called_once_with( "admins__github_user", "admins__nest_user" ) - mock_program_prefetch_related.return_value.get.assert_called_once_with(key="nonexistent") + mock_program_prefetch_related.return_value.aget.assert_called_once_with(key="nonexistent") @patch("apps.mentorship.api.internal.queries.program.Program.objects.prefetch_related") - def test_get_draft_program_hidden_for_anonymous_user( + @pytest.mark.asyncio + async def test_get_draft_program_hidden_for_anonymous_user( self, mock_program_prefetch_related: MagicMock, mock_anonymous_info: MagicMock, @@ -87,17 +92,18 @@ def test_get_draft_program_hidden_for_anonymous_user( """Test that a draft program is not visible to anonymous users.""" mock_program = MagicMock(spec=Program) mock_program.status = Program.ProgramStatus.DRAFT - mock_program_prefetch_related.return_value.get.return_value = mock_program + mock_program_prefetch_related.return_value.aget = AsyncMock(return_value=mock_program) mock_program.user_has_access.return_value = False - result = api_program_queries.get_program( + result = await api_program_queries.get_program( info=mock_anonymous_info, program_key="draft-program" ) assert result is None @patch("apps.mentorship.api.internal.queries.program.Program.objects.prefetch_related") - def test_get_draft_program_visible_for_admin( + @pytest.mark.asyncio + async def test_get_draft_program_visible_for_admin( self, mock_program_prefetch_related: MagicMock, mock_info: MagicMock, @@ -106,10 +112,10 @@ def test_get_draft_program_visible_for_admin( """Test that a draft program is visible to an admin.""" mock_program = MagicMock(spec=Program) mock_program.status = Program.ProgramStatus.DRAFT - mock_program_prefetch_related.return_value.get.return_value = mock_program + mock_program_prefetch_related.return_value.aget = AsyncMock(return_value=mock_program) mock_program.user_has_access.return_value = True - result = api_program_queries.get_program(info=mock_info, program_key="draft-program") + result = await api_program_queries.get_program(info=mock_info, program_key="draft-program") assert result == mock_program diff --git a/backend/tests/unit/apps/owasp/api/internal/queries/stats_test.py b/backend/tests/unit/apps/owasp/api/internal/queries/stats_test.py index f733484a4a..330cac8992 100644 --- a/backend/tests/unit/apps/owasp/api/internal/queries/stats_test.py +++ b/backend/tests/unit/apps/owasp/api/internal/queries/stats_test.py @@ -1,6 +1,8 @@ """Tests for StatsQuery.""" -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest from apps.owasp.api.internal.queries.stats import StatsQuery @@ -8,7 +10,8 @@ class TestStatsQuery: """Test cases for StatsQuery.""" - def test_stats_overview_returns_node(self): + @pytest.mark.asyncio + async def test_stats_overview_returns_node(self): """Test stats_overview returns StatsNode with calculated values.""" mock_workspace = MagicMock() mock_workspace.total_members_count = 5500 @@ -19,15 +22,20 @@ def test_stats_overview_returns_node(self): patch("apps.owasp.api.internal.queries.stats.User") as mock_user, patch("apps.owasp.api.internal.queries.stats.Workspace") as mock_workspace_cls, ): - mock_project.active_projects_count.return_value = 275 - mock_chapter.active_chapters_count.return_value = 342 - mock_user.objects.count.return_value = 15234 - mock_filter = mock_chapter.objects.filter.return_value.exclude.return_value - mock_filter.values.return_value.distinct.return_value.count.return_value = 98 - mock_workspace_cls.get_default_workspace.return_value = mock_workspace + mock_project.active_projects.acount = AsyncMock(return_value=275) + mock_chapter.active_chapters.acount = AsyncMock(return_value=342) + mock_user.objects.acount = AsyncMock(return_value=15234) + + mock_filter_chain = MagicMock() + mock_filter_chain.acount = AsyncMock(return_value=98) + mock_exclude = mock_chapter.objects.filter.return_value.exclude + mock_exclude.return_value.values.return_value.distinct.return_value = mock_filter_chain + + mock_workspace_cls.objects.filter.return_value.afirst = AsyncMock( + return_value=mock_workspace + ) - query = StatsQuery() - result = query.stats_overview() + result = await StatsQuery().stats_overview() assert result.active_projects_stats == 270 assert result.active_chapters_stats == 340 @@ -35,7 +43,8 @@ def test_stats_overview_returns_node(self): assert result.countries_stats == 90 assert result.slack_workspace_stats == 5000 - def test_stats_overview_no_workspace(self): + @pytest.mark.asyncio + async def test_stats_overview_no_workspace(self): """Test stats_overview when no default workspace exists.""" with ( patch("apps.owasp.api.internal.queries.stats.Project") as mock_project, @@ -43,15 +52,18 @@ def test_stats_overview_no_workspace(self): patch("apps.owasp.api.internal.queries.stats.User") as mock_user, patch("apps.owasp.api.internal.queries.stats.Workspace") as mock_workspace_cls, ): - mock_project.active_projects_count.return_value = 10 - mock_chapter.active_chapters_count.return_value = 10 - mock_user.objects.count.return_value = 1000 - mock_filter = mock_chapter.objects.filter.return_value.exclude.return_value - mock_filter.values.return_value.distinct.return_value.count.return_value = 10 - mock_workspace_cls.get_default_workspace.return_value = None - - query = StatsQuery() - result = query.stats_overview() + mock_project.active_projects.acount = AsyncMock(return_value=10) + mock_chapter.active_chapters.acount = AsyncMock(return_value=10) + mock_user.objects.acount = AsyncMock(return_value=1000) + + mock_filter_chain = MagicMock() + mock_filter_chain.acount = AsyncMock(return_value=10) + mock_exclude = mock_chapter.objects.filter.return_value.exclude + mock_exclude.return_value.values.return_value.distinct.return_value = mock_filter_chain + + mock_workspace_cls.objects.filter.return_value.afirst = AsyncMock(return_value=None) + + result = await StatsQuery().stats_overview() assert result.slack_workspace_stats == 0 def test_stats_overview_has_strawberry_definition(self):