Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
29 changes: 29 additions & 0 deletions backend/apps/common/api/internal/dataloaders/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Shared utilities for GraphQL dataloaders."""

from collections import defaultdict

from django.db.models import QuerySet


async def results_by_keys[K, V](
queryset: QuerySet, keys: list[K], key_field: str, value_field: str
) -> list[list[V]]:
"""Map a grouped-results dict back to an ordered list matching ``keys``.

Args:
queryset: The queryset to iterate over.
keys: A list of keys to map the results to, in the desired order.
key_field: The name of the attribute on each item that contains the key.
value_field: The name of the attribute on each item that contains the value.

Returns:
A list of result-lists, one per key, in the same order as ``keys``.

"""
mapping: dict[K, list[V]] = defaultdict(list)
async for item in queryset:
key = getattr(item, key_field)
value = getattr(item, value_field)
mapping[key].append(value)

return [mapping.get(key, []) for key in keys]
10 changes: 10 additions & 0 deletions backend/apps/github/api/internal/dataloaders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from strawberry.dataloader import DataLoader

from apps.github.api.internal.dataloaders.interested_users import make_interested_users_loader


def make_github_dataloaders() -> dict[str, DataLoader]:
"""Return a dict of dataloader instances for GitHub API resolvers."""
return {
"interested_users_loader": make_interested_users_loader(),
}
Comment thread
ahmedxgouda marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""DataLoader for interested users per issue."""

from strawberry.dataloader import DataLoader

from apps.common.api.internal.dataloaders.utils import results_by_keys
from apps.github.models.user import User
from apps.mentorship.models.issue_user_interest import IssueUserInterest


async def load_interested_users(issue_ids: list[int]) -> list[list[User]]:
"""Batch-load interested users for the given issue IDs in a single query."""
interests = (
IssueUserInterest.objects.select_related("user__owasp_profile")
.filter(issue_id__in=issue_ids)
.order_by("user__login")
)
return await results_by_keys(interests, issue_ids, key_field="issue_id", value_field="user")


def make_interested_users_loader() -> DataLoader:
"""Return a per-request DataLoader instance."""
return DataLoader(load_fn=load_interested_users)
17 changes: 3 additions & 14 deletions backend/apps/github/api/internal/nodes/issue.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from apps.github.api.internal.nodes.user import UserNode
from apps.github.models.issue import Issue
from apps.github.models.pull_request import PullRequest
from apps.mentorship.models.issue_user_interest import IssueUserInterest
from apps.mentorship.models.task import Task

MERGED_PULL_REQUESTS_PREFETCH = Prefetch(
Expand Down Expand Up @@ -79,20 +78,10 @@ def is_merged(self, root: Issue) -> bool:
"""Return True if this issue has at least one merged pull request."""
return bool(getattr(root, "merged_pull_requests", None))

@strawberry_django.field(
prefetch_related=[
Prefetch(
"participant_interests",
queryset=IssueUserInterest.objects.select_related("user__owasp_profile").order_by(
"user__login"
),
to_attr="interests_users",
)
]
)
def interested_users(self, root: Issue) -> list[UserNode]:
@strawberry_django.field
async def interested_users(self, root: Issue, info: Info) -> list[UserNode]:
"""Return all users who have expressed interest in this issue."""
return [interest.user for interest in getattr(root, "interests_users", [])]
return await info.context.github_dataloaders["interested_users_loader"].load(root.pk)

@strawberry.field
def task_deadline(self, root: Issue, info: Info) -> datetime | None:
Expand Down
13 changes: 13 additions & 0 deletions backend/settings/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import strawberry
from django.conf import settings
from django.http import HttpRequest, HttpResponse
from strawberry.django.views import AsyncGraphQLView
from strawberry.extensions import DisableIntrospection, QueryDepthLimiter
from strawberry_django.optimizer import DjangoOptimizerExtension

Expand All @@ -19,6 +21,7 @@
)
from apps.nest.api.internal.mutations import NestMutations
from apps.owasp.api.internal.queries import OwaspQuery
from settings.graphql_context import NestGraphQLContext


@strawberry.type
Expand All @@ -43,6 +46,16 @@ class Query(
"""Schema queries."""


class NestGraphQLView(AsyncGraphQLView[NestGraphQLContext, None]):
"""Nest GraphQL view."""

async def get_context(
self, request: HttpRequest, response: HttpResponse
) -> NestGraphQLContext:
"""Return a NestGraphQLContext instance."""
return NestGraphQLContext(request=request, response=response)


extensions = [
QueryDepthLimiter(max_depth=5),
DjangoOptimizerExtension(),
Expand Down
14 changes: 14 additions & 0 deletions backend/settings/graphql_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Custom GraphQL context for OWASP Nest."""

from strawberry.django.context import StrawberryDjangoContext

from apps.github.api.internal.dataloaders import make_github_dataloaders


class NestGraphQLContext(StrawberryDjangoContext):
"""Nest GraphQL context."""

def __init__(self, *args, **kwargs) -> None:
"""Initialize the context with fresh dataloader instances."""
super().__init__(*args, **kwargs)
self.github_dataloaders = make_github_dataloaders()
5 changes: 2 additions & 3 deletions backend/settings/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,22 @@
from django.contrib import admin
from django.urls import include, path
from django.views.decorators.csrf import csrf_protect
from strawberry.django.views import GraphQLView

from apps.api.rest.v0 import api as api_v0
from apps.core.api.internal.algolia import algolia_search
from apps.core.api.internal.csrf import get_csrf_token
from apps.core.api.internal.status import get_status
from apps.owasp.api.internal.views.urls import urlpatterns as owasp_urls
from apps.slack.apps import SlackConfig
from settings.graphql import schema
from settings.graphql import NestGraphQLView, schema

urlpatterns = [
path("csrf/", get_csrf_token),
path("idx/", csrf_protect(algolia_search)),
path(
"graphql/",
csrf_protect(
GraphQLView.as_view(
NestGraphQLView.as_view(
schema=schema,
graphql_ide="graphiql" if settings.DEBUG else None,
)
Expand Down
20 changes: 12 additions & 8 deletions backend/tests/unit/apps/github/api/internal/nodes/issue_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Test cases for IssueNode."""

import asyncio
from datetime import UTC, datetime
from unittest.mock import Mock, patch
from unittest.mock import AsyncMock, Mock, patch

from apps.github.api.internal.nodes.issue import IssueNode
from tests.unit.apps.common.graphql_node_base_test import GraphQLNodeBaseTest
Expand Down Expand Up @@ -151,23 +152,26 @@ def test_is_merged_false(self):
assert not result

def test_interested_users(self):
"""Test interested_users field returns list of users from prefetched interests_users."""
"""Test interested_users field returns users from the interested users dataloader."""
mock_issue = Mock()
mock_issue.pk = 123
mock_user1 = Mock()
mock_user1.login = "user1"
mock_user2 = Mock()
mock_user2.login = "user2"

mock_interest1 = Mock()
mock_interest1.user = mock_user1
mock_interest2 = Mock()
mock_interest2.user = mock_user2
mock_loader = Mock()
mock_loader.load = AsyncMock(return_value=[mock_user1, mock_user2])

mock_issue.interests_users = [mock_interest1, mock_interest2]
mock_info = Mock()
mock_info.context = Mock()
mock_info.context.github_dataloaders = {"interested_users_loader": mock_loader}

field = self._get_field_by_name("interested_users", IssueNode)
result = field.base_resolver.wrapped_func(None, mock_issue)
result = asyncio.run(field.base_resolver.wrapped_func(None, mock_issue, mock_info))

assert result == [mock_user1, mock_user2]
mock_loader.load.assert_awaited_once_with(mock_issue.pk)

def test_task_deadline_with_bulk_load_mapping(self):
"""Test task_deadline field when mapping exists with issue deadline."""
Expand Down
Loading