From 97a5317bd98403e6cb32a670233db543ca47e8f5 Mon Sep 17 00:00:00 2001 From: aman Date: Wed, 27 May 2026 18:42:02 +0530 Subject: [PATCH] fix(api): dedupe access_pairs and map invalid-id/unknown-perm to 4xx Group successCheckPairs by resource id in ListProjectsByCurrentUser, ListServiceUserProjects, and ListCurrentUserGroups so each resource appears once. Resolve permissions once in fetchAccessPairsOnResource, drop unknown names and duplicate inputs. Validate Principal id in project.Service.List and service-user id in serviceuser.Service.Get and map the typed errors to InvalidArgument in the handlers. Co-Authored-By: Claude Opus 4.7 (1M context) --- core/project/service.go | 7 +- core/project/service_test.go | 39 ++- core/serviceuser/service.go | 3 + core/serviceuser/service_test.go | 50 ++++ .../api/v1beta1connect/permission_check.go | 65 +++-- internal/api/v1beta1connect/serviceuser.go | 40 ++- .../api/v1beta1connect/serviceuser_test.go | 168 +++++++++++ internal/api/v1beta1connect/user.go | 58 ++-- internal/api/v1beta1connect/user_test.go | 264 +++++++++++++++++- 9 files changed, 618 insertions(+), 76 deletions(-) diff --git a/core/project/service.go b/core/project/service.go index deac0f329..c48c86c97 100644 --- a/core/project/service.go +++ b/core/project/service.go @@ -133,8 +133,11 @@ func (s Service) Create(ctx context.Context, prj Project) (Project, error) { func (s Service) List(ctx context.Context, f Filter) ([]Project, error) { if f.Principal != nil { - if f.Principal.ID == "" || f.Principal.Type == "" { - return nil, fmt.Errorf("project: invalid principal filter") + if !utils.IsValidUUID(f.Principal.ID) { + return nil, ErrInvalidUUID + } + if f.Principal.Type == "" { + return nil, ErrInvalidPrincipalType } if s.membershipService == nil { return nil, fmt.Errorf("project: membership service not wired") diff --git a/core/project/service_test.go b/core/project/service_test.go index 7289ff1b8..dfdad3c41 100644 --- a/core/project/service_test.go +++ b/core/project/service_test.go @@ -288,14 +288,15 @@ func TestService_List(t *testing.T) { func TestService_List_WithPrincipal(t *testing.T) { ctx := context.Background() - userPrincipal := authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal} + userPrincipal := authenticate.Principal{ID: "68f86fec-eb87-49f0-9be0-8d99b00a4a9c", Type: schema.UserPrincipal} tests := []struct { - name string - setup func(*testing.T) *project.Service - filter project.Filter - want []project.Project - wantErr bool + name string + setup func(*testing.T) *project.Service + filter project.Filter + want []project.Project + wantErr bool + wantErrIs error }{ { name: "errors when membership service is not wired", @@ -309,24 +310,37 @@ func TestService_List_WithPrincipal(t *testing.T) { wantErr: true, }, { - name: "errors when Principal has empty ID", + name: "returns ErrInvalidUUID when Principal has empty ID", filter: project.Filter{Principal: &authenticate.Principal{Type: schema.UserPrincipal}}, setup: func(t *testing.T) *project.Service { t.Helper() repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t) return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) }, - wantErr: true, + wantErr: true, + wantErrIs: project.ErrInvalidUUID, }, { - name: "errors when Principal has empty Type", - filter: project.Filter{Principal: &authenticate.Principal{ID: "user-id"}}, + name: "returns ErrInvalidUUID when Principal ID is not a valid UUID", + filter: project.Filter{Principal: &authenticate.Principal{ID: "not-a-uuid", Type: schema.UserPrincipal}}, setup: func(t *testing.T) *project.Service { t.Helper() repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t) return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) }, - wantErr: true, + wantErr: true, + wantErrIs: project.ErrInvalidUUID, + }, + { + name: "returns ErrInvalidPrincipalType when Principal has empty Type", + filter: project.Filter{Principal: &authenticate.Principal{ID: "68f86fec-eb87-49f0-9be0-8d99b00a4a9c"}}, + setup: func(t *testing.T) *project.Service { + t.Helper() + repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t) + return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) + }, + wantErr: true, + wantErrIs: project.ErrInvalidPrincipalType, }, { name: "returns projects from the membership shim", @@ -462,6 +476,9 @@ func TestService_List_WithPrincipal(t *testing.T) { t.Errorf("List() error = %v, wantErr %v", err, tt.wantErr) return } + if tt.wantErrIs != nil && !errors.Is(err, tt.wantErrIs) { + t.Errorf("List() error = %v, want errors.Is(%v)", err, tt.wantErrIs) + } if diff := cmp.Diff(got, tt.want); diff != "" { t.Errorf("List() mismatch (-want +got):\n%s", diff) } diff --git a/core/serviceuser/service.go b/core/serviceuser/service.go index 0ae9212ee..bebeb7e08 100644 --- a/core/serviceuser/service.go +++ b/core/serviceuser/service.go @@ -112,6 +112,9 @@ func (s Service) Create(ctx context.Context, serviceUser ServiceUser) (ServiceUs } func (s Service) Get(ctx context.Context, id string) (ServiceUser, error) { + if !utils.IsValidUUID(id) { + return ServiceUser{}, ErrInvalidID + } return s.repo.GetByID(ctx, id) } diff --git a/core/serviceuser/service_test.go b/core/serviceuser/service_test.go index 86d2cc289..d65f7f91a 100644 --- a/core/serviceuser/service_test.go +++ b/core/serviceuser/service_test.go @@ -113,3 +113,53 @@ func TestService_Delete(t *testing.T) { }) } } + +func TestService_Get(t *testing.T) { + ctx := context.Background() + const validID = "68f86fec-eb87-49f0-9be0-8d99b00a4a9c" + + tests := []struct { + name string + id string + setup func(*mocks.Repository) + wantErrIs error + }{ + { + name: "empty id returns ErrInvalidID without hitting the repo", + id: "", + setup: func(repo *mocks.Repository) {}, + wantErrIs: serviceuser.ErrInvalidID, + }, + { + name: "non-uuid id returns ErrInvalidID without hitting the repo", + id: "not-a-uuid", + setup: func(repo *mocks.Repository) {}, + wantErrIs: serviceuser.ErrInvalidID, + }, + { + name: "valid uuid delegates to the repo", + id: validID, + setup: func(repo *mocks.Repository) { + repo.On("GetByID", ctx, validID).Return(serviceuser.ServiceUser{ID: validID}, nil) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc, repo, _, _, _ := newTestService(t) + tt.setup(repo) + + _, err := svc.Get(ctx, tt.id) + if tt.wantErrIs != nil { + if !errors.Is(err, tt.wantErrIs) { + t.Errorf("Get() error = %v, want errors.Is(%v)", err, tt.wantErrIs) + } + return + } + if err != nil { + t.Errorf("Get() unexpected error = %v", err) + } + }) + } +} diff --git a/internal/api/v1beta1connect/permission_check.go b/internal/api/v1beta1connect/permission_check.go index c6ebe9eff..3af7e923b 100644 --- a/internal/api/v1beta1connect/permission_check.go +++ b/internal/api/v1beta1connect/permission_check.go @@ -28,24 +28,36 @@ func logAuditForCheck(ctx context.Context, result bool, objectID string, objectN } func (h *ConnectHandler) getPermissionName(ctx context.Context, ns, name string) (string, error) { + resolved, ok, err := h.resolvePermissionName(ctx, ns, name) + if err != nil { + return "", connect.NewError(connect.CodeInternal, ErrInternalServerError) + } + if !ok { + return "", connect.NewError(connect.CodeNotFound, ErrNotFound) + } + return resolved, nil +} + +// resolvePermissionName looks up the canonical permission name for (namespace, name). +// ok=false means the permission is not defined; err is reserved for genuine lookup +// failures. Callers that want to treat an unknown permission as "no result" +// should use this helper; callers that want to reject the request should use +// getPermissionName which maps unknown permissions to CodeNotFound. +func (h *ConnectHandler) resolvePermissionName(ctx context.Context, ns, name string) (string, bool, error) { if ns == schema.PlatformNamespace && schema.IsPlatformPermission(name) { - return name, nil + return name, true, nil } perm, err := h.permissionService.Get(ctx, permission.AddNamespaceIfRequired(ns, name)) if err != nil { - switch { - case errors.Is(err, permission.ErrNotExist): - return "", connect.NewError(connect.CodeNotFound, ErrNotFound) - default: - return "", connect.NewError(connect.CodeInternal, ErrInternalServerError) + if errors.Is(err, permission.ErrNotExist) { + return "", false, nil } + return "", false, err } - // if the permission is on the same namespace as the object, use the name if perm.NamespaceID == ns { - return perm.Name, nil + return perm.Name, true, nil } - // else use fully qualified name(slug) - return perm.Slug, nil + return perm.Slug, true, nil } func (h *ConnectHandler) CheckFederatedResourcePermission(ctx context.Context, req *connect.Request[frontierv1beta1.CheckFederatedResourcePermissionRequest]) (*connect.Response[frontierv1beta1.CheckFederatedResourcePermissionResponse], error) { @@ -94,19 +106,38 @@ func (h *ConnectHandler) CheckFederatedResourcePermission(ctx context.Context, r } func (h *ConnectHandler) fetchAccessPairsOnResource(ctx context.Context, objectNamespace string, ids, permissions []string) ([]relation.CheckPair, error) { - checks := make([]resource.Check, 0, len(ids)*len(permissions)) + // Resolve each requested permission once, dropping unknown names and + // duplicate inputs. Unknown names produce an empty result rather than + // 4xx/5xx — see the contract on resolvePermissionName. + resolvedPerms := make([]string, 0, len(permissions)) + seen := make(map[string]struct{}, len(permissions)) + for _, p := range permissions { + resolved, ok, err := h.resolvePermissionName(ctx, objectNamespace, p) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + } + if !ok { + continue + } + if _, dup := seen[resolved]; dup { + continue + } + seen[resolved] = struct{}{} + resolvedPerms = append(resolvedPerms, resolved) + } + if len(resolvedPerms) == 0 || len(ids) == 0 { + return []relation.CheckPair{}, nil + } + + checks := make([]resource.Check, 0, len(ids)*len(resolvedPerms)) for _, id := range ids { - for _, permission := range permissions { - permissionName, err := h.getPermissionName(ctx, objectNamespace, permission) - if err != nil { - return nil, err - } + for _, p := range resolvedPerms { checks = append(checks, resource.Check{ Object: relation.Object{ ID: id, Namespace: objectNamespace, }, - Permission: permissionName, + Permission: p, }) } } diff --git a/internal/api/v1beta1connect/serviceuser.go b/internal/api/v1beta1connect/serviceuser.go index 76daf229d..30bd3a5b3 100644 --- a/internal/api/v1beta1connect/serviceuser.go +++ b/internal/api/v1beta1connect/serviceuser.go @@ -10,7 +10,6 @@ import ( "github.com/raystack/frontier/core/authenticate" "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/project" - "github.com/raystack/frontier/core/relation" "github.com/raystack/frontier/core/serviceuser" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/raystack/frontier/pkg/errors" @@ -99,7 +98,9 @@ func (h *ConnectHandler) GetServiceUser(ctx context.Context, request *connect.Re "service_user_id", serviceUserID) switch { - case err == serviceuser.ErrNotExist: + case errors.Is(err, serviceuser.ErrInvalidID): + return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) + case errors.Is(err, serviceuser.ErrNotExist): return nil, connect.NewError(connect.CodeNotFound, ErrServiceUserNotFound) default: errorLogger.LogUnexpectedError(ctx, request, "GetServiceUser", err, @@ -475,7 +476,17 @@ func (h *ConnectHandler) ListServiceUserProjects(ctx context.Context, request *c errorLogger.LogServiceError(ctx, request, "ListServiceUserProjects", err, "service_user_id", serviceUserID, "org_id", orgID) - return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + + switch { + case errors.Is(err, project.ErrInvalidUUID), + errors.Is(err, project.ErrInvalidPrincipalType): + return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) + default: + errorLogger.LogUnexpectedError(ctx, request, "ListServiceUserProjects", err, + "service_user_id", serviceUserID, + "org_id", orgID) + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + } } var projects []*frontierv1beta1.Project @@ -501,20 +512,19 @@ func (h *ConnectHandler) ListServiceUserProjects(ctx context.Context, request *c "with_permissions", request.Msg.GetWithPermissions()) return nil, err } - for _, successCheck := range successCheckPairs { - resID := successCheck.Relation.Object.ID - - // find all permission checks on same resource - pairsForCurrentGroup := utils.Filter(successCheckPairs, func(pair relation.CheckPair) bool { - return pair.Relation.Object.ID == resID - }) - // fetch permissions - permissions := utils.Map(pairsForCurrentGroup, func(pair relation.CheckPair) string { - return pair.Relation.RelationName - }) + permsByProject := map[string][]string{} + projectOrder := make([]string, 0, len(projList)) + for _, p := range successCheckPairs { + resID := p.Relation.Object.ID + if _, seen := permsByProject[resID]; !seen { + projectOrder = append(projectOrder, resID) + } + permsByProject[resID] = append(permsByProject[resID], p.Relation.RelationName) + } + for _, resID := range projectOrder { accessPairsPb = append(accessPairsPb, &frontierv1beta1.ListServiceUserProjectsResponse_AccessPair{ ProjectId: resID, - Permissions: permissions, + Permissions: permsByProject[resID], }) } } diff --git a/internal/api/v1beta1connect/serviceuser_test.go b/internal/api/v1beta1connect/serviceuser_test.go index 1f15a2059..88b1337d7 100644 --- a/internal/api/v1beta1connect/serviceuser_test.go +++ b/internal/api/v1beta1connect/serviceuser_test.go @@ -1296,6 +1296,56 @@ func TestHandler_DeleteServiceUserToken(t *testing.T) { } } +func TestHandler_GetServiceUser(t *testing.T) { + tests := []struct { + name string + setup func(*mocks.ServiceUserService) + request *connect.Request[frontierv1beta1.GetServiceUserRequest] + errCode connect.Code + wantErr error + }{ + { + name: "maps ErrInvalidID to InvalidArgument", + setup: func(svc *mocks.ServiceUserService) { + svc.EXPECT().Get(mock.Anything, "not-a-uuid").Return(serviceuser.ServiceUser{}, serviceuser.ErrInvalidID) + }, + request: connect.NewRequest(&frontierv1beta1.GetServiceUserRequest{Id: "not-a-uuid"}), + errCode: connect.CodeInvalidArgument, + wantErr: ErrBadRequest, + }, + { + name: "maps ErrNotExist to NotFound", + setup: func(svc *mocks.ServiceUserService) { + svc.EXPECT().Get(mock.Anything, testServiceUserID).Return(serviceuser.ServiceUser{}, serviceuser.ErrNotExist) + }, + request: connect.NewRequest(&frontierv1beta1.GetServiceUserRequest{Id: testServiceUserID}), + errCode: connect.CodeNotFound, + wantErr: ErrServiceUserNotFound, + }, + { + name: "maps unexpected error to Internal", + setup: func(svc *mocks.ServiceUserService) { + svc.EXPECT().Get(mock.Anything, testServiceUserID).Return(serviceuser.ServiceUser{}, errors.New("boom")) + }, + request: connect.NewRequest(&frontierv1beta1.GetServiceUserRequest{Id: testServiceUserID}), + errCode: connect.CodeInternal, + wantErr: ErrInternalServerError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc := &mocks.ServiceUserService{} + tt.setup(svc) + h := &ConnectHandler{serviceUserService: svc} + resp, err := h.GetServiceUser(context.Background(), tt.request) + assert.Nil(t, resp) + assert.Error(t, err) + assert.Equal(t, tt.errCode, connect.CodeOf(err)) + assert.Contains(t, err.Error(), tt.wantErr.Error()) + }) + } +} + func TestHandler_ListServiceUserProjects(t *testing.T) { testProjectMap := map[string]project.Project{ "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71": { @@ -1398,6 +1448,20 @@ func TestHandler_ListServiceUserProjects(t *testing.T) { wantErr: ErrBadRequest, errCode: connect.CodeInvalidArgument, }, + { + name: "should return invalid argument when project service returns ErrInvalidUUID", + request: connect.NewRequest(&frontierv1beta1.ListServiceUserProjectsRequest{ + Id: "not-a-uuid", + }), + setup: func(projSvc *mocks.ProjectService, permSvc *mocks.PermissionService, resourceSvc *mocks.ResourceService) { + projSvc.EXPECT().List(mock.Anything, project.Filter{ + Principal: &authenticate.Principal{ID: "not-a-uuid", Type: schema.ServiceUserPrincipal}, + }).Return(nil, project.ErrInvalidUUID) + }, + want: nil, + wantErr: ErrBadRequest, + errCode: connect.CodeInvalidArgument, + }, { name: "should forward org_id to project.Filter when set", request: connect.NewRequest(&frontierv1beta1.ListServiceUserProjectsRequest{ @@ -1517,6 +1581,110 @@ func TestHandler_ListServiceUserProjects(t *testing.T) { wantErr: nil, errCode: connect.Code(0), }, + { + name: "emits one access pair per project when multiple permissions succeed", + request: connect.NewRequest(&frontierv1beta1.ListServiceUserProjectsRequest{ + Id: "1", + WithPermissions: []string{"update", "delete"}, + }), + setup: func(projSvc *mocks.ProjectService, permSvc *mocks.PermissionService, resourceSvc *mocks.ResourceService) { + var projects []project.Project + for _, projectID := range testProjectIDList { + projects = append(projects, testProjectMap[projectID]) + } + projSvc.EXPECT().List(mock.Anything, project.Filter{Principal: &authenticate.Principal{ID: "1", Type: schema.ServiceUserPrincipal}}).Return(projects, nil) + + permSvc.EXPECT().Get(mock.Anything, "app/project:update").Return( + permission.Permission{Name: "update", NamespaceID: "app/project"}, nil) + permSvc.EXPECT().Get(mock.Anything, "app/project:delete").Return( + permission.Permission{Name: "delete", NamespaceID: "app/project"}, nil) + + resourceSvc.EXPECT().BatchCheck(mock.Anything, []resource.Check{ + {Object: relation.Object{ID: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", Namespace: "app/project"}, Permission: "update"}, + {Object: relation.Object{ID: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", Namespace: "app/project"}, Permission: "delete"}, + {Object: relation.Object{ID: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", Namespace: "app/project"}, Permission: "update"}, + {Object: relation.Object{ID: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", Namespace: "app/project"}, Permission: "delete"}, + }).Return([]relation.CheckPair{ + {Relation: relation.Relation{Object: relation.Object{ID: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", Namespace: "app/project"}, RelationName: "update"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", Namespace: "app/project"}, RelationName: "delete"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", Namespace: "app/project"}, RelationName: "update"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", Namespace: "app/project"}, RelationName: "delete"}, Status: true}, + }, nil) + }, + want: connect.NewResponse(&frontierv1beta1.ListServiceUserProjectsResponse{ + Projects: []*frontierv1beta1.Project{{ + Id: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", + Name: "prj-1", + Metadata: &structpb.Struct{Fields: map[string]*structpb.Value{"email": structpb.NewStringValue("org1@org1.com")}}, + OrgId: "9f256f86-31a3-11ec-8d3d-0242ac130003", + CreatedAt: timestamppb.New(time.Time{}), + UpdatedAt: timestamppb.New(time.Time{}), + }, { + Id: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", + Name: "prj-2", + Metadata: &structpb.Struct{Fields: map[string]*structpb.Value{"email": structpb.NewStringValue("org1@org2.com")}}, + OrgId: "9f256f86-31a3-11ec-8d3d-0242ac130003", + CreatedAt: timestamppb.New(time.Time{}), + UpdatedAt: timestamppb.New(time.Time{}), + }}, + AccessPairs: []*frontierv1beta1.ListServiceUserProjectsResponse_AccessPair{ + {ProjectId: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", Permissions: []string{"update", "delete"}}, + {ProjectId: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", Permissions: []string{"update", "delete"}}, + }, + }), + wantErr: nil, + errCode: connect.Code(0), + }, + { + name: "drops unknown permissions from withPermissions", + request: connect.NewRequest(&frontierv1beta1.ListServiceUserProjectsRequest{ + Id: "1", + WithPermissions: []string{"get", "bogus"}, + }), + setup: func(projSvc *mocks.ProjectService, permSvc *mocks.PermissionService, resourceSvc *mocks.ResourceService) { + var projects []project.Project + for _, projectID := range testProjectIDList { + projects = append(projects, testProjectMap[projectID]) + } + projSvc.EXPECT().List(mock.Anything, project.Filter{Principal: &authenticate.Principal{ID: "1", Type: schema.ServiceUserPrincipal}}).Return(projects, nil) + + permSvc.EXPECT().Get(mock.Anything, "app/project:get").Return( + permission.Permission{Name: "get", NamespaceID: "app/project"}, nil) + permSvc.EXPECT().Get(mock.Anything, "app/project:bogus").Return( + permission.Permission{}, permission.ErrNotExist) + + resourceSvc.EXPECT().BatchCheck(mock.Anything, []resource.Check{ + {Object: relation.Object{ID: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", Namespace: "app/project"}, Permission: "get"}, + {Object: relation.Object{ID: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", Namespace: "app/project"}, Permission: "get"}, + }).Return([]relation.CheckPair{ + {Relation: relation.Relation{Object: relation.Object{ID: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", Namespace: "app/project"}, RelationName: "get"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", Namespace: "app/project"}, RelationName: "get"}, Status: true}, + }, nil) + }, + want: connect.NewResponse(&frontierv1beta1.ListServiceUserProjectsResponse{ + Projects: []*frontierv1beta1.Project{{ + Id: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", + Name: "prj-1", + Metadata: &structpb.Struct{Fields: map[string]*structpb.Value{"email": structpb.NewStringValue("org1@org1.com")}}, + OrgId: "9f256f86-31a3-11ec-8d3d-0242ac130003", + CreatedAt: timestamppb.New(time.Time{}), + UpdatedAt: timestamppb.New(time.Time{}), + }, { + Id: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", + Name: "prj-2", + Metadata: &structpb.Struct{Fields: map[string]*structpb.Value{"email": structpb.NewStringValue("org1@org2.com")}}, + OrgId: "9f256f86-31a3-11ec-8d3d-0242ac130003", + CreatedAt: timestamppb.New(time.Time{}), + UpdatedAt: timestamppb.New(time.Time{}), + }}, + AccessPairs: []*frontierv1beta1.ListServiceUserProjectsResponse_AccessPair{ + {ProjectId: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", Permissions: []string{"get"}}, + {ProjectId: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", Permissions: []string{"get"}}, + }, + }), + wantErr: nil, + errCode: connect.Code(0), + }, } for _, tt := range tests { diff --git a/internal/api/v1beta1connect/user.go b/internal/api/v1beta1connect/user.go index 45ff7e43b..5a1298e58 100644 --- a/internal/api/v1beta1connect/user.go +++ b/internal/api/v1beta1connect/user.go @@ -3,13 +3,12 @@ package v1beta1connect import ( "context" "fmt" + "log/slog" "net/mail" "strings" "connectrpc.com/connect" - "log/slog" - "github.com/pkg/errors" "github.com/raystack/frontier/core/audit" "github.com/raystack/frontier/core/authenticate" @@ -17,7 +16,6 @@ import ( "github.com/raystack/frontier/core/membership" "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/project" - "github.com/raystack/frontier/core/relation" "github.com/raystack/frontier/core/user" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/raystack/frontier/internal/store/postgres" @@ -553,20 +551,19 @@ func (h *ConnectHandler) ListCurrentUserGroups(ctx context.Context, request *con "org_id", request.Msg.GetOrgId()) return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) } - for _, successCheck := range successCheckPairs { - resID := successCheck.Relation.Object.ID - - // find all permission checks on same resource - pairsForCurrentGroup := utils.Filter(successCheckPairs, func(pair relation.CheckPair) bool { - return pair.Relation.Object.ID == resID - }) - // fetch permissions - permissions := utils.Map(pairsForCurrentGroup, func(pair relation.CheckPair) string { - return pair.Relation.RelationName - }) + permsByGroup := map[string][]string{} + groupOrder := make([]string, 0, len(groupsList)) + for _, p := range successCheckPairs { + resID := p.Relation.Object.ID + if _, seen := permsByGroup[resID]; !seen { + groupOrder = append(groupOrder, resID) + } + permsByGroup[resID] = append(permsByGroup[resID], p.Relation.RelationName) + } + for _, resID := range groupOrder { accessPairsPb = append(accessPairsPb, &frontierv1beta1.ListCurrentUserGroupsResponse_AccessPair{ GroupId: resID, - Permissions: permissions, + Permissions: permsByGroup[resID], }) } } @@ -876,10 +873,11 @@ func (h *ConnectHandler) ListProjectsByUser(ctx context.Context, request *connec "user_id", userID) switch { + case errors.Is(err, project.ErrInvalidUUID), + errors.Is(err, project.ErrInvalidPrincipalType): + return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) case errors.Is(err, user.ErrNotExist): return nil, connect.NewError(connect.CodeNotFound, ErrNotFound) - case errors.Is(err, user.ErrInvalidUUID): - return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) default: errorLogger.LogUnexpectedError(ctx, request, "ListProjectsByUser", err, "user_id", userID) @@ -949,20 +947,22 @@ func (h *ConnectHandler) ListProjectsByCurrentUser(ctx context.Context, request "org_id", request.Msg.GetOrgId()) return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) } - for _, successCheck := range successCheckPairs { - resID := successCheck.Relation.Object.ID - - // find all permission checks on same resource - pairsForCurrentGroup := utils.Filter(successCheckPairs, func(pair relation.CheckPair) bool { - return pair.Relation.Object.ID == resID - }) - // fetch permissions - permissions := utils.Map(pairsForCurrentGroup, func(pair relation.CheckPair) string { - return pair.Relation.RelationName - }) + // Group permissions by project id, emit one access pair per project in + // first-seen order. successCheckPairs is unique by (resID, permName) so + // no per-permission dedup is needed here. + permsByProject := map[string][]string{} + projectOrder := make([]string, 0, len(projList)) + for _, p := range successCheckPairs { + resID := p.Relation.Object.ID + if _, seen := permsByProject[resID]; !seen { + projectOrder = append(projectOrder, resID) + } + permsByProject[resID] = append(permsByProject[resID], p.Relation.RelationName) + } + for _, resID := range projectOrder { accessPairsPb = append(accessPairsPb, &frontierv1beta1.ListProjectsByCurrentUserResponse_AccessPair{ ProjectId: resID, - Permissions: permissions, + Permissions: permsByProject[resID], }) } } diff --git a/internal/api/v1beta1connect/user_test.go b/internal/api/v1beta1connect/user_test.go index 73e475170..e5801b888 100644 --- a/internal/api/v1beta1connect/user_test.go +++ b/internal/api/v1beta1connect/user_test.go @@ -10,7 +10,10 @@ import ( "github.com/raystack/frontier/core/authenticate" "github.com/raystack/frontier/core/group" "github.com/raystack/frontier/core/organization" + "github.com/raystack/frontier/core/permission" "github.com/raystack/frontier/core/project" + "github.com/raystack/frontier/core/relation" + "github.com/raystack/frontier/core/resource" "github.com/raystack/frontier/core/serviceuser" "github.com/raystack/frontier/core/user" "github.com/raystack/frontier/internal/api/v1beta1connect/mocks" @@ -1209,6 +1212,110 @@ func TestConnectHandler_ListCurrentUserGroups(t *testing.T) { } } +func TestConnectHandler_ListCurrentUserGroups_AccessPairs(t *testing.T) { + const ( + groupA = "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71" + groupB = "c7772c63-fca4-4c7c-bf93-c8f85115de4b" + ) + principal := authenticate.Principal{ + ID: "9f256f86-31a3-11ec-8d3d-0242ac130003", + Type: schema.UserPrincipal, + User: &user.User{ID: "9f256f86-31a3-11ec-8d3d-0242ac130003"}, + } + + resolvedPermission := func(name string) permission.Permission { + return permission.Permission{Name: name, NamespaceID: schema.GroupNamespace} + } + + tests := []struct { + title string + withPermissions []string + setup func(*mocks.PermissionService, *mocks.ResourceService) + wantAccessPairs []*frontierv1beta1.ListCurrentUserGroupsResponse_AccessPair + }{ + { + title: "emits one access pair per group when multiple permissions succeed", + withPermissions: []string{"update", "delete"}, + setup: func(perm *mocks.PermissionService, res *mocks.ResourceService) { + perm.EXPECT().Get(mock.Anything, "app/group:update").Return(resolvedPermission("update"), nil) + perm.EXPECT().Get(mock.Anything, "app/group:delete").Return(resolvedPermission("delete"), nil) + res.EXPECT().BatchCheck(mock.Anything, []resource.Check{ + {Object: relation.Object{ID: groupA, Namespace: schema.GroupNamespace}, Permission: "update"}, + {Object: relation.Object{ID: groupA, Namespace: schema.GroupNamespace}, Permission: "delete"}, + {Object: relation.Object{ID: groupB, Namespace: schema.GroupNamespace}, Permission: "update"}, + {Object: relation.Object{ID: groupB, Namespace: schema.GroupNamespace}, Permission: "delete"}, + }).Return([]relation.CheckPair{ + {Relation: relation.Relation{Object: relation.Object{ID: groupA, Namespace: schema.GroupNamespace}, RelationName: "update"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: groupA, Namespace: schema.GroupNamespace}, RelationName: "delete"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: groupB, Namespace: schema.GroupNamespace}, RelationName: "update"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: groupB, Namespace: schema.GroupNamespace}, RelationName: "delete"}, Status: true}, + }, nil) + }, + wantAccessPairs: []*frontierv1beta1.ListCurrentUserGroupsResponse_AccessPair{ + {GroupId: groupA, Permissions: []string{"update", "delete"}}, + {GroupId: groupB, Permissions: []string{"update", "delete"}}, + }, + }, + { + title: "drops unknown permissions and returns access pairs only for the known ones", + withPermissions: []string{"update", "bogus"}, + setup: func(perm *mocks.PermissionService, res *mocks.ResourceService) { + perm.EXPECT().Get(mock.Anything, "app/group:update").Return(resolvedPermission("update"), nil) + perm.EXPECT().Get(mock.Anything, "app/group:bogus").Return(permission.Permission{}, permission.ErrNotExist) + res.EXPECT().BatchCheck(mock.Anything, []resource.Check{ + {Object: relation.Object{ID: groupA, Namespace: schema.GroupNamespace}, Permission: "update"}, + {Object: relation.Object{ID: groupB, Namespace: schema.GroupNamespace}, Permission: "update"}, + }).Return([]relation.CheckPair{ + {Relation: relation.Relation{Object: relation.Object{ID: groupA, Namespace: schema.GroupNamespace}, RelationName: "update"}, Status: true}, + }, nil) + }, + wantAccessPairs: []*frontierv1beta1.ListCurrentUserGroupsResponse_AccessPair{ + {GroupId: groupA, Permissions: []string{"update"}}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.title, func(t *testing.T) { + mockGroupSrv := new(mocks.GroupService) + mockAuthnSrv := new(mocks.AuthnService) + mockPermissionSrv := new(mocks.PermissionService) + mockResourceSrv := new(mocks.ResourceService) + + mockAuthnSrv.EXPECT().GetPrincipal(mock.Anything).Return(principal, nil) + mockGroupSrv.EXPECT().List(mock.Anything, mock.MatchedBy(func(f group.Filter) bool { + return f.Principal != nil && *f.Principal == principal + })).Return([]group.Group{ + {ID: groupA, OrganizationID: "org-1"}, + {ID: groupB, OrganizationID: "org-1"}, + }, nil) + if tt.setup != nil { + tt.setup(mockPermissionSrv, mockResourceSrv) + } + + handler := &ConnectHandler{ + groupService: mockGroupSrv, + authnService: mockAuthnSrv, + permissionService: mockPermissionSrv, + resourceService: mockResourceSrv, + } + + req := connect.NewRequest(&frontierv1beta1.ListCurrentUserGroupsRequest{ + WithPermissions: tt.withPermissions, + }) + resp, err := handler.ListCurrentUserGroups(context.Background(), req) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, tt.wantAccessPairs, resp.Msg.GetAccessPairs()) + + mockGroupSrv.AssertExpectations(t) + mockAuthnSrv.AssertExpectations(t) + mockPermissionSrv.AssertExpectations(t) + mockResourceSrv.AssertExpectations(t) + }) + } +} + func TestConnectHandler_ListOrganizationsByUser(t *testing.T) { userID := uuid.New().String() @@ -1634,14 +1741,23 @@ func TestConnectHandler_ListProjectsByUser(t *testing.T) { err: connect.CodeNotFound, }, { - title: "should return bad request error for invalid user ID", + title: "should return bad request error when project service returns ErrInvalidUUID", setup: func(ps *mocks.ProjectService, as *mocks.AuthnService) { - ps.EXPECT().List(mock.Anything, project.Filter{Principal: &authenticate.Principal{ID: "invalid-id", Type: schema.UserPrincipal}}).Return(nil, user.ErrInvalidUUID) + ps.EXPECT().List(mock.Anything, project.Filter{Principal: &authenticate.Principal{ID: "invalid-id", Type: schema.UserPrincipal}}).Return(nil, project.ErrInvalidUUID) }, req: &frontierv1beta1.ListProjectsByUserRequest{Id: "invalid-id"}, want: nil, err: connect.CodeInvalidArgument, }, + { + title: "should return bad request error when project service returns ErrInvalidPrincipalType", + setup: func(ps *mocks.ProjectService, as *mocks.AuthnService) { + ps.EXPECT().List(mock.Anything, project.Filter{Principal: &authenticate.Principal{ID: "user-1", Type: schema.UserPrincipal}}).Return(nil, project.ErrInvalidPrincipalType) + }, + req: &frontierv1beta1.ListProjectsByUserRequest{Id: "user-1"}, + want: nil, + err: connect.CodeInvalidArgument, + }, { title: "should return internal error for project service failure", setup: func(ps *mocks.ProjectService, as *mocks.AuthnService) { @@ -1914,3 +2030,147 @@ func TestConnectHandler_ListProjectsByCurrentUser(t *testing.T) { }) } } + +func TestConnectHandler_ListProjectsByCurrentUser_AccessPairs(t *testing.T) { + const ( + projA = "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71" + projB = "c7772c63-fca4-4c7c-bf93-c8f85115de4b" + ) + principal := authenticate.Principal{ + ID: "9f256f86-31a3-11ec-8d3d-0242ac130003", + Type: schema.UserPrincipal, + User: &user.User{ID: "9f256f86-31a3-11ec-8d3d-0242ac130003"}, + } + + resolvedPermission := func(name string) permission.Permission { + return permission.Permission{Name: name, NamespaceID: schema.ProjectNamespace} + } + + tests := []struct { + title string + withPermissions []string + setup func(*mocks.PermissionService, *mocks.ResourceService) + wantAccessPairs []*frontierv1beta1.ListProjectsByCurrentUserResponse_AccessPair + wantErr connect.Code + }{ + { + title: "emits one access pair per project when multiple permissions succeed", + withPermissions: []string{"update", "delete"}, + setup: func(perm *mocks.PermissionService, res *mocks.ResourceService) { + perm.EXPECT().Get(mock.Anything, "app/project:update").Return(resolvedPermission("update"), nil) + perm.EXPECT().Get(mock.Anything, "app/project:delete").Return(resolvedPermission("delete"), nil) + res.EXPECT().BatchCheck(mock.Anything, []resource.Check{ + {Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, Permission: "update"}, + {Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, Permission: "delete"}, + {Object: relation.Object{ID: projB, Namespace: schema.ProjectNamespace}, Permission: "update"}, + {Object: relation.Object{ID: projB, Namespace: schema.ProjectNamespace}, Permission: "delete"}, + }).Return([]relation.CheckPair{ + {Relation: relation.Relation{Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, RelationName: "update"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, RelationName: "delete"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: projB, Namespace: schema.ProjectNamespace}, RelationName: "update"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: projB, Namespace: schema.ProjectNamespace}, RelationName: "delete"}, Status: true}, + }, nil) + }, + wantAccessPairs: []*frontierv1beta1.ListProjectsByCurrentUserResponse_AccessPair{ + {ProjectId: projA, Permissions: []string{"update", "delete"}}, + {ProjectId: projB, Permissions: []string{"update", "delete"}}, + }, + }, + { + title: "drops unknown permissions and returns access pairs only for the known ones", + withPermissions: []string{"update", "bogus"}, + setup: func(perm *mocks.PermissionService, res *mocks.ResourceService) { + perm.EXPECT().Get(mock.Anything, "app/project:update").Return(resolvedPermission("update"), nil) + perm.EXPECT().Get(mock.Anything, "app/project:bogus").Return(permission.Permission{}, permission.ErrNotExist) + res.EXPECT().BatchCheck(mock.Anything, []resource.Check{ + {Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, Permission: "update"}, + {Object: relation.Object{ID: projB, Namespace: schema.ProjectNamespace}, Permission: "update"}, + }).Return([]relation.CheckPair{ + {Relation: relation.Relation{Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, RelationName: "update"}, Status: true}, + }, nil) + }, + wantAccessPairs: []*frontierv1beta1.ListProjectsByCurrentUserResponse_AccessPair{ + {ProjectId: projA, Permissions: []string{"update"}}, + }, + }, + { + title: "returns empty access pairs when every requested permission is unknown", + withPermissions: []string{"bogus1", "bogus2"}, + setup: func(perm *mocks.PermissionService, res *mocks.ResourceService) { + perm.EXPECT().Get(mock.Anything, "app/project:bogus1").Return(permission.Permission{}, permission.ErrNotExist) + perm.EXPECT().Get(mock.Anything, "app/project:bogus2").Return(permission.Permission{}, permission.ErrNotExist) + // resourceService.BatchCheck must NOT be called. + }, + wantAccessPairs: nil, + }, + { + title: "deduplicates repeated permission inputs", + withPermissions: []string{"update", "update", "delete"}, + setup: func(perm *mocks.PermissionService, res *mocks.ResourceService) { + perm.EXPECT().Get(mock.Anything, "app/project:update").Return(resolvedPermission("update"), nil).Times(2) + perm.EXPECT().Get(mock.Anything, "app/project:delete").Return(resolvedPermission("delete"), nil) + // Each (project, permission) appears exactly once even though "update" was requested twice. + res.EXPECT().BatchCheck(mock.Anything, []resource.Check{ + {Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, Permission: "update"}, + {Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, Permission: "delete"}, + {Object: relation.Object{ID: projB, Namespace: schema.ProjectNamespace}, Permission: "update"}, + {Object: relation.Object{ID: projB, Namespace: schema.ProjectNamespace}, Permission: "delete"}, + }).Return([]relation.CheckPair{ + {Relation: relation.Relation{Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, RelationName: "update"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, RelationName: "delete"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: projB, Namespace: schema.ProjectNamespace}, RelationName: "update"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: projB, Namespace: schema.ProjectNamespace}, RelationName: "delete"}, Status: true}, + }, nil) + }, + wantAccessPairs: []*frontierv1beta1.ListProjectsByCurrentUserResponse_AccessPair{ + {ProjectId: projA, Permissions: []string{"update", "delete"}}, + {ProjectId: projB, Permissions: []string{"update", "delete"}}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.title, func(t *testing.T) { + mockProjectSrv := new(mocks.ProjectService) + mockAuthnSrv := new(mocks.AuthnService) + mockPermissionSrv := new(mocks.PermissionService) + mockResourceSrv := new(mocks.ResourceService) + + mockAuthnSrv.EXPECT().GetPrincipal(mock.Anything).Return(principal, nil) + mockProjectSrv.EXPECT().List(mock.Anything, mock.MatchedBy(func(f project.Filter) bool { + return f.Principal != nil && *f.Principal == principal + })).Return([]project.Project{ + {ID: projA, Organization: organization.Organization{ID: "org-1"}}, + {ID: projB, Organization: organization.Organization{ID: "org-1"}}, + }, nil) + if tt.setup != nil { + tt.setup(mockPermissionSrv, mockResourceSrv) + } + + handler := &ConnectHandler{ + projectService: mockProjectSrv, + authnService: mockAuthnSrv, + permissionService: mockPermissionSrv, + resourceService: mockResourceSrv, + } + + req := connect.NewRequest(&frontierv1beta1.ListProjectsByCurrentUserRequest{ + WithPermissions: tt.withPermissions, + }) + resp, err := handler.ListProjectsByCurrentUser(context.Background(), req) + if tt.wantErr != connect.Code(0) { + assert.Nil(t, resp) + assert.Equal(t, tt.wantErr, connect.CodeOf(err)) + return + } + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, tt.wantAccessPairs, resp.Msg.GetAccessPairs()) + + mockProjectSrv.AssertExpectations(t) + mockAuthnSrv.AssertExpectations(t) + mockPermissionSrv.AssertExpectations(t) + mockResourceSrv.AssertExpectations(t) + }) + } +}