diff --git a/core/deleter/mocks/membership_service.go b/core/deleter/mocks/membership_service.go index 9275835e7..2f2c6ccda 100644 --- a/core/deleter/mocks/membership_service.go +++ b/core/deleter/mocks/membership_service.go @@ -5,6 +5,10 @@ package mocks import ( context "context" + authenticate "github.com/raystack/frontier/core/authenticate" + + membership "github.com/raystack/frontier/core/membership" + mock "github.com/stretchr/testify/mock" ) @@ -21,6 +25,67 @@ func (_m *MembershipService) EXPECT() *MembershipService_Expecter { return &MembershipService_Expecter{mock: &_m.Mock} } +// ListResourcesByPrincipal provides a mock function with given fields: ctx, principal, resourceType, filter +func (_m *MembershipService) ListResourcesByPrincipal(ctx context.Context, principal authenticate.Principal, resourceType string, filter membership.ResourceFilter) ([]string, error) { + ret := _m.Called(ctx, principal, resourceType, filter) + + if len(ret) == 0 { + panic("no return value specified for ListResourcesByPrincipal") + } + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, string, membership.ResourceFilter) ([]string, error)); ok { + return rf(ctx, principal, resourceType, filter) + } + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, string, membership.ResourceFilter) []string); ok { + r0 = rf(ctx, principal, resourceType, filter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal, string, membership.ResourceFilter) error); ok { + r1 = rf(ctx, principal, resourceType, filter) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MembershipService_ListResourcesByPrincipal_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListResourcesByPrincipal' +type MembershipService_ListResourcesByPrincipal_Call struct { + *mock.Call +} + +// ListResourcesByPrincipal is a helper method to define mock.On call +// - ctx context.Context +// - principal authenticate.Principal +// - resourceType string +// - filter membership.ResourceFilter +func (_e *MembershipService_Expecter) ListResourcesByPrincipal(ctx interface{}, principal interface{}, resourceType interface{}, filter interface{}) *MembershipService_ListResourcesByPrincipal_Call { + return &MembershipService_ListResourcesByPrincipal_Call{Call: _e.mock.On("ListResourcesByPrincipal", ctx, principal, resourceType, filter)} +} + +func (_c *MembershipService_ListResourcesByPrincipal_Call) Run(run func(ctx context.Context, principal authenticate.Principal, resourceType string, filter membership.ResourceFilter)) *MembershipService_ListResourcesByPrincipal_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(authenticate.Principal), args[2].(string), args[3].(membership.ResourceFilter)) + }) + return _c +} + +func (_c *MembershipService_ListResourcesByPrincipal_Call) Return(_a0 []string, _a1 error) *MembershipService_ListResourcesByPrincipal_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MembershipService_ListResourcesByPrincipal_Call) RunAndReturn(run func(context.Context, authenticate.Principal, string, membership.ResourceFilter) ([]string, error)) *MembershipService_ListResourcesByPrincipal_Call { + _c.Call.Return(run) + return _c +} + // OnGroupDeleted provides a mock function with given fields: ctx, groupID func (_m *MembershipService) OnGroupDeleted(ctx context.Context, groupID string) error { ret := _m.Called(ctx, groupID) diff --git a/core/deleter/mocks/organization_service.go b/core/deleter/mocks/organization_service.go index e597d36a7..3ef7c9532 100644 --- a/core/deleter/mocks/organization_service.go +++ b/core/deleter/mocks/organization_service.go @@ -5,8 +5,6 @@ package mocks import ( context "context" - authenticate "github.com/raystack/frontier/core/authenticate" - mock "github.com/stretchr/testify/mock" organization "github.com/raystack/frontier/core/organization" @@ -129,66 +127,6 @@ func (_c *OrganizationService_Get_Call) RunAndReturn(run func(context.Context, s return _c } -// ListByUser provides a mock function with given fields: ctx, principal, f -func (_m *OrganizationService) ListByUser(ctx context.Context, principal authenticate.Principal, f organization.Filter) ([]organization.Organization, error) { - ret := _m.Called(ctx, principal, f) - - if len(ret) == 0 { - panic("no return value specified for ListByUser") - } - - var r0 []organization.Organization - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, organization.Filter) ([]organization.Organization, error)); ok { - return rf(ctx, principal, f) - } - if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, organization.Filter) []organization.Organization); ok { - r0 = rf(ctx, principal, f) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]organization.Organization) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal, organization.Filter) error); ok { - r1 = rf(ctx, principal, f) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// OrganizationService_ListByUser_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListByUser' -type OrganizationService_ListByUser_Call struct { - *mock.Call -} - -// ListByUser is a helper method to define mock.On call -// - ctx context.Context -// - principal authenticate.Principal -// - f organization.Filter -func (_e *OrganizationService_Expecter) ListByUser(ctx interface{}, principal interface{}, f interface{}) *OrganizationService_ListByUser_Call { - return &OrganizationService_ListByUser_Call{Call: _e.mock.On("ListByUser", ctx, principal, f)} -} - -func (_c *OrganizationService_ListByUser_Call) Run(run func(ctx context.Context, principal authenticate.Principal, f organization.Filter)) *OrganizationService_ListByUser_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(authenticate.Principal), args[2].(organization.Filter)) - }) - return _c -} - -func (_c *OrganizationService_ListByUser_Call) Return(_a0 []organization.Organization, _a1 error) *OrganizationService_ListByUser_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *OrganizationService_ListByUser_Call) RunAndReturn(run func(context.Context, authenticate.Principal, organization.Filter) ([]organization.Organization, error)) *OrganizationService_ListByUser_Call { - _c.Call.Return(run) - return _c -} - // RemoveUsers provides a mock function with given fields: ctx, orgID, userIDs func (_m *OrganizationService) RemoveUsers(ctx context.Context, orgID string, userIDs []string) error { ret := _m.Called(ctx, orgID, userIDs) diff --git a/core/deleter/service.go b/core/deleter/service.go index 0ccaeefeb..4cbee75b3 100644 --- a/core/deleter/service.go +++ b/core/deleter/service.go @@ -20,6 +20,7 @@ import ( "github.com/google/uuid" "github.com/raystack/frontier/core/invitation" + "github.com/raystack/frontier/core/membership" "github.com/raystack/frontier/core/policy" "github.com/raystack/frontier/core/role" @@ -44,7 +45,6 @@ type OrganizationService interface { Get(ctx context.Context, id string) (organization.Organization, error) DeleteModel(ctx context.Context, id string) error RemoveUsers(ctx context.Context, orgID string, userIDs []string) error - ListByUser(ctx context.Context, principal authenticate.Principal, f organization.Filter) ([]organization.Organization, error) } type RoleService interface { @@ -71,6 +71,7 @@ type GroupService interface { type MembershipService interface { OnGroupDeleted(ctx context.Context, groupID string) error + ListResourcesByPrincipal(ctx context.Context, principal authenticate.Principal, resourceType string, filter membership.ResourceFilter) ([]string, error) } type InvitationService interface { @@ -378,17 +379,19 @@ func (d Service) RemoveUsersFromOrg(ctx context.Context, orgID string, userIDs [ return d.orgService.RemoveUsers(ctx, orgID, userIDs) } +// DeleteUser visits every org the user has a policy on (disabled orgs too), +// otherwise userService.Delete would leave orphan policy rows behind. func (d Service) DeleteUser(ctx context.Context, userID string) error { - userOrgs, err := d.orgService.ListByUser(ctx, authenticate.Principal{ + orgIDs, err := d.membershipService.ListResourcesByPrincipal(ctx, authenticate.Principal{ ID: userID, Type: schema.UserPrincipal, - }, organization.Filter{}) + }, schema.OrganizationNamespace, membership.ResourceFilter{}) if err != nil { return err } - for _, org := range userOrgs { - if err = d.RemoveUsersFromOrg(ctx, org.ID, []string{userID}); err != nil { - return fmt.Errorf("failed to delete user from org[%s]: %w", org.Name, err) + for _, orgID := range orgIDs { + if err = d.RemoveUsersFromOrg(ctx, orgID, []string{userID}); err != nil { + return fmt.Errorf("failed to delete user from org[%s]: %w", orgID, err) } } return d.userService.Delete(ctx, userID) diff --git a/core/deleter/service_test.go b/core/deleter/service_test.go index 19da0b94f..fb16ec2ae 100644 --- a/core/deleter/service_test.go +++ b/core/deleter/service_test.go @@ -17,6 +17,7 @@ import ( "github.com/raystack/frontier/core/resource" "github.com/raystack/frontier/core/role" "github.com/raystack/frontier/core/serviceuser" + "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -259,7 +260,7 @@ func TestDeleteUser(t *testing.T) { t.Run("removes user from all orgs then deletes", func(t *testing.T) { orgSvc, projSvc, resSvc, grpSvc, mbrSvc, polSvc, roleSvc, invSvc, usrSvc, suSvc, custSvc, subSvc, invocSvc := newMocks(t) - orgSvc.EXPECT().ListByUser(mock.Anything, mock.Anything, mock.Anything). + mbrSvc.EXPECT().ListResourcesByPrincipal(mock.Anything, mock.Anything, schema.OrganizationNamespace, mock.Anything). Return(nil, nil) usrSvc.EXPECT().Delete(mock.Anything, "user-1").Return(nil) diff --git a/core/domain/mocks/membership_service.go b/core/domain/mocks/membership_service.go new file mode 100644 index 000000000..7054e2668 --- /dev/null +++ b/core/domain/mocks/membership_service.go @@ -0,0 +1,151 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + authenticate "github.com/raystack/frontier/core/authenticate" + + membership "github.com/raystack/frontier/core/membership" + + mock "github.com/stretchr/testify/mock" +) + +// MembershipService is an autogenerated mock type for the MembershipService type +type MembershipService struct { + mock.Mock +} + +type MembershipService_Expecter struct { + mock *mock.Mock +} + +func (_m *MembershipService) EXPECT() *MembershipService_Expecter { + return &MembershipService_Expecter{mock: &_m.Mock} +} + +// AddOrganizationMember provides a mock function with given fields: ctx, orgID, principalID, principalType, roleID +func (_m *MembershipService) AddOrganizationMember(ctx context.Context, orgID string, principalID string, principalType string, roleID string) error { + ret := _m.Called(ctx, orgID, principalID, principalType, roleID) + + if len(ret) == 0 { + panic("no return value specified for AddOrganizationMember") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = rf(ctx, orgID, principalID, principalType, roleID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MembershipService_AddOrganizationMember_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddOrganizationMember' +type MembershipService_AddOrganizationMember_Call struct { + *mock.Call +} + +// AddOrganizationMember is a helper method to define mock.On call +// - ctx context.Context +// - orgID string +// - principalID string +// - principalType string +// - roleID string +func (_e *MembershipService_Expecter) AddOrganizationMember(ctx interface{}, orgID interface{}, principalID interface{}, principalType interface{}, roleID interface{}) *MembershipService_AddOrganizationMember_Call { + return &MembershipService_AddOrganizationMember_Call{Call: _e.mock.On("AddOrganizationMember", ctx, orgID, principalID, principalType, roleID)} +} + +func (_c *MembershipService_AddOrganizationMember_Call) Run(run func(ctx context.Context, orgID string, principalID string, principalType string, roleID string)) *MembershipService_AddOrganizationMember_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + }) + return _c +} + +func (_c *MembershipService_AddOrganizationMember_Call) Return(_a0 error) *MembershipService_AddOrganizationMember_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_AddOrganizationMember_Call) RunAndReturn(run func(context.Context, string, string, string, string) error) *MembershipService_AddOrganizationMember_Call { + _c.Call.Return(run) + return _c +} + +// ListResourcesByPrincipal provides a mock function with given fields: ctx, principal, resourceType, filter +func (_m *MembershipService) ListResourcesByPrincipal(ctx context.Context, principal authenticate.Principal, resourceType string, filter membership.ResourceFilter) ([]string, error) { + ret := _m.Called(ctx, principal, resourceType, filter) + + if len(ret) == 0 { + panic("no return value specified for ListResourcesByPrincipal") + } + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, string, membership.ResourceFilter) ([]string, error)); ok { + return rf(ctx, principal, resourceType, filter) + } + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, string, membership.ResourceFilter) []string); ok { + r0 = rf(ctx, principal, resourceType, filter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal, string, membership.ResourceFilter) error); ok { + r1 = rf(ctx, principal, resourceType, filter) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MembershipService_ListResourcesByPrincipal_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListResourcesByPrincipal' +type MembershipService_ListResourcesByPrincipal_Call struct { + *mock.Call +} + +// ListResourcesByPrincipal is a helper method to define mock.On call +// - ctx context.Context +// - principal authenticate.Principal +// - resourceType string +// - filter membership.ResourceFilter +func (_e *MembershipService_Expecter) ListResourcesByPrincipal(ctx interface{}, principal interface{}, resourceType interface{}, filter interface{}) *MembershipService_ListResourcesByPrincipal_Call { + return &MembershipService_ListResourcesByPrincipal_Call{Call: _e.mock.On("ListResourcesByPrincipal", ctx, principal, resourceType, filter)} +} + +func (_c *MembershipService_ListResourcesByPrincipal_Call) Run(run func(ctx context.Context, principal authenticate.Principal, resourceType string, filter membership.ResourceFilter)) *MembershipService_ListResourcesByPrincipal_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(authenticate.Principal), args[2].(string), args[3].(membership.ResourceFilter)) + }) + return _c +} + +func (_c *MembershipService_ListResourcesByPrincipal_Call) Return(_a0 []string, _a1 error) *MembershipService_ListResourcesByPrincipal_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MembershipService_ListResourcesByPrincipal_Call) RunAndReturn(run func(context.Context, authenticate.Principal, string, membership.ResourceFilter) ([]string, error)) *MembershipService_ListResourcesByPrincipal_Call { + _c.Call.Return(run) + return _c +} + +// NewMembershipService creates a new instance of MembershipService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMembershipService(t interface { + mock.TestingT + Cleanup(func()) +}) *MembershipService { + mock := &MembershipService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/domain/mocks/org_service.go b/core/domain/mocks/org_service.go new file mode 100644 index 000000000..bd341d2b1 --- /dev/null +++ b/core/domain/mocks/org_service.go @@ -0,0 +1,95 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + organization "github.com/raystack/frontier/core/organization" +) + +// OrgService is an autogenerated mock type for the OrgService type +type OrgService struct { + mock.Mock +} + +type OrgService_Expecter struct { + mock *mock.Mock +} + +func (_m *OrgService) EXPECT() *OrgService_Expecter { + return &OrgService_Expecter{mock: &_m.Mock} +} + +// Get provides a mock function with given fields: ctx, id +func (_m *OrgService) Get(ctx context.Context, id string) (organization.Organization, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 organization.Organization + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (organization.Organization, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string) organization.Organization); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(organization.Organization) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// OrgService_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type OrgService_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *OrgService_Expecter) Get(ctx interface{}, id interface{}) *OrgService_Get_Call { + return &OrgService_Get_Call{Call: _e.mock.On("Get", ctx, id)} +} + +func (_c *OrgService_Get_Call) Run(run func(ctx context.Context, id string)) *OrgService_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *OrgService_Get_Call) Return(_a0 organization.Organization, _a1 error) *OrgService_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *OrgService_Get_Call) RunAndReturn(run func(context.Context, string) (organization.Organization, error)) *OrgService_Get_Call { + _c.Call.Return(run) + return _c +} + +// NewOrgService creates a new instance of OrgService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewOrgService(t interface { + mock.TestingT + Cleanup(func()) +}) *OrgService { + mock := &OrgService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/domain/mocks/repository.go b/core/domain/mocks/repository.go new file mode 100644 index 000000000..daab1a2ad --- /dev/null +++ b/core/domain/mocks/repository.go @@ -0,0 +1,360 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + domain "github.com/raystack/frontier/core/domain" + mock "github.com/stretchr/testify/mock" +) + +// Repository is an autogenerated mock type for the Repository type +type Repository struct { + mock.Mock +} + +type Repository_Expecter struct { + mock *mock.Mock +} + +func (_m *Repository) EXPECT() *Repository_Expecter { + return &Repository_Expecter{mock: &_m.Mock} +} + +// Create provides a mock function with given fields: ctx, _a1 +func (_m *Repository) Create(ctx context.Context, _a1 domain.Domain) (domain.Domain, error) { + ret := _m.Called(ctx, _a1) + + if len(ret) == 0 { + panic("no return value specified for Create") + } + + var r0 domain.Domain + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, domain.Domain) (domain.Domain, error)); ok { + return rf(ctx, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, domain.Domain) domain.Domain); ok { + r0 = rf(ctx, _a1) + } else { + r0 = ret.Get(0).(domain.Domain) + } + + if rf, ok := ret.Get(1).(func(context.Context, domain.Domain) error); ok { + r1 = rf(ctx, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' +type Repository_Create_Call struct { + *mock.Call +} + +// Create is a helper method to define mock.On call +// - ctx context.Context +// - _a1 domain.Domain +func (_e *Repository_Expecter) Create(ctx interface{}, _a1 interface{}) *Repository_Create_Call { + return &Repository_Create_Call{Call: _e.mock.On("Create", ctx, _a1)} +} + +func (_c *Repository_Create_Call) Run(run func(ctx context.Context, _a1 domain.Domain)) *Repository_Create_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(domain.Domain)) + }) + return _c +} + +func (_c *Repository_Create_Call) Return(_a0 domain.Domain, _a1 error) *Repository_Create_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_Create_Call) RunAndReturn(run func(context.Context, domain.Domain) (domain.Domain, error)) *Repository_Create_Call { + _c.Call.Return(run) + return _c +} + +// Delete provides a mock function with given fields: ctx, id +func (_m *Repository) Delete(ctx context.Context, id string) error { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Repository_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type Repository_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Repository_Expecter) Delete(ctx interface{}, id interface{}) *Repository_Delete_Call { + return &Repository_Delete_Call{Call: _e.mock.On("Delete", ctx, id)} +} + +func (_c *Repository_Delete_Call) Run(run func(ctx context.Context, id string)) *Repository_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Repository_Delete_Call) Return(_a0 error) *Repository_Delete_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Repository_Delete_Call) RunAndReturn(run func(context.Context, string) error) *Repository_Delete_Call { + _c.Call.Return(run) + return _c +} + +// DeleteExpiredDomainRequests provides a mock function with given fields: ctx +func (_m *Repository) DeleteExpiredDomainRequests(ctx context.Context) error { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for DeleteExpiredDomainRequests") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Repository_DeleteExpiredDomainRequests_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteExpiredDomainRequests' +type Repository_DeleteExpiredDomainRequests_Call struct { + *mock.Call +} + +// DeleteExpiredDomainRequests is a helper method to define mock.On call +// - ctx context.Context +func (_e *Repository_Expecter) DeleteExpiredDomainRequests(ctx interface{}) *Repository_DeleteExpiredDomainRequests_Call { + return &Repository_DeleteExpiredDomainRequests_Call{Call: _e.mock.On("DeleteExpiredDomainRequests", ctx)} +} + +func (_c *Repository_DeleteExpiredDomainRequests_Call) Run(run func(ctx context.Context)) *Repository_DeleteExpiredDomainRequests_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *Repository_DeleteExpiredDomainRequests_Call) Return(_a0 error) *Repository_DeleteExpiredDomainRequests_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Repository_DeleteExpiredDomainRequests_Call) RunAndReturn(run func(context.Context) error) *Repository_DeleteExpiredDomainRequests_Call { + _c.Call.Return(run) + return _c +} + +// Get provides a mock function with given fields: ctx, id +func (_m *Repository) Get(ctx context.Context, id string) (domain.Domain, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 domain.Domain + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (domain.Domain, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string) domain.Domain); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(domain.Domain) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type Repository_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Repository_Expecter) Get(ctx interface{}, id interface{}) *Repository_Get_Call { + return &Repository_Get_Call{Call: _e.mock.On("Get", ctx, id)} +} + +func (_c *Repository_Get_Call) Run(run func(ctx context.Context, id string)) *Repository_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Repository_Get_Call) Return(_a0 domain.Domain, _a1 error) *Repository_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_Get_Call) RunAndReturn(run func(context.Context, string) (domain.Domain, error)) *Repository_Get_Call { + _c.Call.Return(run) + return _c +} + +// List provides a mock function with given fields: ctx, flt +func (_m *Repository) List(ctx context.Context, flt domain.Filter) ([]domain.Domain, error) { + ret := _m.Called(ctx, flt) + + if len(ret) == 0 { + panic("no return value specified for List") + } + + var r0 []domain.Domain + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, domain.Filter) ([]domain.Domain, error)); ok { + return rf(ctx, flt) + } + if rf, ok := ret.Get(0).(func(context.Context, domain.Filter) []domain.Domain); ok { + r0 = rf(ctx, flt) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]domain.Domain) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, domain.Filter) error); ok { + r1 = rf(ctx, flt) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' +type Repository_List_Call struct { + *mock.Call +} + +// List is a helper method to define mock.On call +// - ctx context.Context +// - flt domain.Filter +func (_e *Repository_Expecter) List(ctx interface{}, flt interface{}) *Repository_List_Call { + return &Repository_List_Call{Call: _e.mock.On("List", ctx, flt)} +} + +func (_c *Repository_List_Call) Run(run func(ctx context.Context, flt domain.Filter)) *Repository_List_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(domain.Filter)) + }) + return _c +} + +func (_c *Repository_List_Call) Return(_a0 []domain.Domain, _a1 error) *Repository_List_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_List_Call) RunAndReturn(run func(context.Context, domain.Filter) ([]domain.Domain, error)) *Repository_List_Call { + _c.Call.Return(run) + return _c +} + +// Update provides a mock function with given fields: ctx, _a1 +func (_m *Repository) Update(ctx context.Context, _a1 domain.Domain) (domain.Domain, error) { + ret := _m.Called(ctx, _a1) + + if len(ret) == 0 { + panic("no return value specified for Update") + } + + var r0 domain.Domain + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, domain.Domain) (domain.Domain, error)); ok { + return rf(ctx, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, domain.Domain) domain.Domain); ok { + r0 = rf(ctx, _a1) + } else { + r0 = ret.Get(0).(domain.Domain) + } + + if rf, ok := ret.Get(1).(func(context.Context, domain.Domain) error); ok { + r1 = rf(ctx, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' +type Repository_Update_Call struct { + *mock.Call +} + +// Update is a helper method to define mock.On call +// - ctx context.Context +// - _a1 domain.Domain +func (_e *Repository_Expecter) Update(ctx interface{}, _a1 interface{}) *Repository_Update_Call { + return &Repository_Update_Call{Call: _e.mock.On("Update", ctx, _a1)} +} + +func (_c *Repository_Update_Call) Run(run func(ctx context.Context, _a1 domain.Domain)) *Repository_Update_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(domain.Domain)) + }) + return _c +} + +func (_c *Repository_Update_Call) Return(_a0 domain.Domain, _a1 error) *Repository_Update_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_Update_Call) RunAndReturn(run func(context.Context, domain.Domain) (domain.Domain, error)) *Repository_Update_Call { + _c.Call.Return(run) + return _c +} + +// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *Repository { + mock := &Repository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/domain/mocks/user_service.go b/core/domain/mocks/user_service.go new file mode 100644 index 000000000..2b3f7b4de --- /dev/null +++ b/core/domain/mocks/user_service.go @@ -0,0 +1,95 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + user "github.com/raystack/frontier/core/user" +) + +// UserService is an autogenerated mock type for the UserService type +type UserService struct { + mock.Mock +} + +type UserService_Expecter struct { + mock *mock.Mock +} + +func (_m *UserService) EXPECT() *UserService_Expecter { + return &UserService_Expecter{mock: &_m.Mock} +} + +// GetByID provides a mock function with given fields: ctx, id +func (_m *UserService) GetByID(ctx context.Context, id string) (user.User, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for GetByID") + } + + var r0 user.User + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (user.User, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string) user.User); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(user.User) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UserService_GetByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByID' +type UserService_GetByID_Call struct { + *mock.Call +} + +// GetByID is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *UserService_Expecter) GetByID(ctx interface{}, id interface{}) *UserService_GetByID_Call { + return &UserService_GetByID_Call{Call: _e.mock.On("GetByID", ctx, id)} +} + +func (_c *UserService_GetByID_Call) Run(run func(ctx context.Context, id string)) *UserService_GetByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *UserService_GetByID_Call) Return(_a0 user.User, _a1 error) *UserService_GetByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *UserService_GetByID_Call) RunAndReturn(run func(context.Context, string) (user.User, error)) *UserService_GetByID_Call { + _c.Call.Return(run) + return _c +} + +// NewUserService creates a new instance of UserService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewUserService(t interface { + mock.TestingT + Cleanup(func()) +}) *UserService { + mock := &UserService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/domain/service.go b/core/domain/service.go index 3b6cbd653..ff622f127 100644 --- a/core/domain/service.go +++ b/core/domain/service.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "net" + "slices" "strings" "time" @@ -28,11 +29,11 @@ type UserService interface { type OrgService interface { Get(ctx context.Context, id string) (organization.Organization, error) - ListByUser(ctx context.Context, principal authenticate.Principal, filter organization.Filter) ([]organization.Organization, error) } type MembershipService interface { AddOrganizationMember(ctx context.Context, orgID, principalID, principalType, roleID string) error + ListResourcesByPrincipal(ctx context.Context, principal authenticate.Principal, resourceType string, filter membership.ResourceFilter) ([]string, error) } type Service struct { @@ -185,25 +186,17 @@ func (s Service) ListJoinableOrgsByDomain(ctx context.Context, email string) ([] return nil, err } - userOrgs, err := s.orgService.ListByUser(ctx, authenticate.Principal{ + memberOrgIDs, err := s.membershipService.ListResourcesByPrincipal(ctx, authenticate.Principal{ ID: currUser.ID, Type: schema.UserPrincipal, - }, organization.Filter{}) + }, schema.OrganizationNamespace, membership.ResourceFilter{}) if err != nil { return nil, err } var orgIDs []string - var alreadyMember bool for _, domain := range domains { - alreadyMember = false - for _, org := range userOrgs { - if org.ID == domain.OrgID { - alreadyMember = true - break - } - } - if !alreadyMember { + if !slices.Contains(memberOrgIDs, domain.OrgID) { orgIDs = append(orgIDs, domain.OrgID) } } diff --git a/core/domain/service_test.go b/core/domain/service_test.go new file mode 100644 index 000000000..01e73f8a8 --- /dev/null +++ b/core/domain/service_test.go @@ -0,0 +1,87 @@ +package domain_test + +import ( + "context" + "log/slog" + "testing" + + "github.com/raystack/frontier/core/authenticate" + "github.com/raystack/frontier/core/domain" + "github.com/raystack/frontier/core/domain/mocks" + "github.com/raystack/frontier/core/user" + "github.com/raystack/frontier/internal/bootstrap/schema" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestService_ListJoinableOrgsByDomain(t *testing.T) { + ctx := context.Background() + email := "alice@example.com" + userID := "user-1" + + newService := func(t *testing.T) (*domain.Service, *mocks.Repository, *mocks.UserService, *mocks.MembershipService) { + t.Helper() + repo := mocks.NewRepository(t) + userSvc := mocks.NewUserService(t) + orgSvc := mocks.NewOrgService(t) + memberSvc := mocks.NewMembershipService(t) + svc := domain.NewService(slog.Default(), repo, userSvc, orgSvc, memberSvc) + return svc, repo, userSvc, memberSvc + } + + t.Run("returns verified-domain orgs the user is not a member of", func(t *testing.T) { + svc, repo, userSvc, memberSvc := newService(t) + + repo.EXPECT().List(ctx, domain.Filter{Name: "example.com", State: domain.Verified}). + Return([]domain.Domain{ + {OrgID: "org-1"}, + {OrgID: "org-2"}, + {OrgID: "org-3"}, + }, nil) + + userSvc.EXPECT().GetByID(ctx, email).Return(user.User{ID: userID}, nil) + + memberSvc.EXPECT().ListResourcesByPrincipal(ctx, authenticate.Principal{ + ID: userID, Type: schema.UserPrincipal, + }, schema.OrganizationNamespace, mock.Anything).Return([]string{"org-2"}, nil) + + got, err := svc.ListJoinableOrgsByDomain(ctx, email) + assert.NoError(t, err) + assert.Equal(t, []string{"org-1", "org-3"}, got) + }) + + t.Run("excludes disabled-org policy-holder from joinable list", func(t *testing.T) { + // A stale policy on a disabled org still counts as membership — + // otherwise we'd offer the disabled org as joinable. + svc, repo, userSvc, memberSvc := newService(t) + + repo.EXPECT().List(ctx, domain.Filter{Name: "example.com", State: domain.Verified}). + Return([]domain.Domain{{OrgID: "org-disabled"}}, nil) + + userSvc.EXPECT().GetByID(ctx, email).Return(user.User{ID: userID}, nil) + + // Membership returns the disabled org because it's policy-based, not state-aware. + memberSvc.EXPECT().ListResourcesByPrincipal(ctx, mock.Anything, schema.OrganizationNamespace, mock.Anything). + Return([]string{"org-disabled"}, nil) + + got, err := svc.ListJoinableOrgsByDomain(ctx, email) + assert.NoError(t, err) + assert.Empty(t, got) + }) + + t.Run("returns all verified-domain orgs when user has no memberships", func(t *testing.T) { + svc, repo, userSvc, memberSvc := newService(t) + + repo.EXPECT().List(ctx, domain.Filter{Name: "example.com", State: domain.Verified}). + Return([]domain.Domain{{OrgID: "org-1"}, {OrgID: "org-2"}}, nil) + + userSvc.EXPECT().GetByID(ctx, email).Return(user.User{ID: userID}, nil) + + memberSvc.EXPECT().ListResourcesByPrincipal(ctx, mock.Anything, schema.OrganizationNamespace, mock.Anything). + Return(nil, nil) + + got, err := svc.ListJoinableOrgsByDomain(ctx, email) + assert.NoError(t, err) + assert.Equal(t, []string{"org-1", "org-2"}, got) + }) +} diff --git a/core/invitation/mocks/membership_service.go b/core/invitation/mocks/membership_service.go index 189b7ae63..cb02a4c6c 100644 --- a/core/invitation/mocks/membership_service.go +++ b/core/invitation/mocks/membership_service.go @@ -5,6 +5,10 @@ package mocks import ( context "context" + authenticate "github.com/raystack/frontier/core/authenticate" + + membership "github.com/raystack/frontier/core/membership" + mock "github.com/stretchr/testify/mock" ) @@ -71,6 +75,67 @@ func (_c *MembershipService_AddOrganizationMember_Call) RunAndReturn(run func(co return _c } +// ListResourcesByPrincipal provides a mock function with given fields: ctx, principal, resourceType, filter +func (_m *MembershipService) ListResourcesByPrincipal(ctx context.Context, principal authenticate.Principal, resourceType string, filter membership.ResourceFilter) ([]string, error) { + ret := _m.Called(ctx, principal, resourceType, filter) + + if len(ret) == 0 { + panic("no return value specified for ListResourcesByPrincipal") + } + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, string, membership.ResourceFilter) ([]string, error)); ok { + return rf(ctx, principal, resourceType, filter) + } + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, string, membership.ResourceFilter) []string); ok { + r0 = rf(ctx, principal, resourceType, filter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal, string, membership.ResourceFilter) error); ok { + r1 = rf(ctx, principal, resourceType, filter) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MembershipService_ListResourcesByPrincipal_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListResourcesByPrincipal' +type MembershipService_ListResourcesByPrincipal_Call struct { + *mock.Call +} + +// ListResourcesByPrincipal is a helper method to define mock.On call +// - ctx context.Context +// - principal authenticate.Principal +// - resourceType string +// - filter membership.ResourceFilter +func (_e *MembershipService_Expecter) ListResourcesByPrincipal(ctx interface{}, principal interface{}, resourceType interface{}, filter interface{}) *MembershipService_ListResourcesByPrincipal_Call { + return &MembershipService_ListResourcesByPrincipal_Call{Call: _e.mock.On("ListResourcesByPrincipal", ctx, principal, resourceType, filter)} +} + +func (_c *MembershipService_ListResourcesByPrincipal_Call) Run(run func(ctx context.Context, principal authenticate.Principal, resourceType string, filter membership.ResourceFilter)) *MembershipService_ListResourcesByPrincipal_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(authenticate.Principal), args[2].(string), args[3].(membership.ResourceFilter)) + }) + return _c +} + +func (_c *MembershipService_ListResourcesByPrincipal_Call) Return(_a0 []string, _a1 error) *MembershipService_ListResourcesByPrincipal_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MembershipService_ListResourcesByPrincipal_Call) RunAndReturn(run func(context.Context, authenticate.Principal, string, membership.ResourceFilter) ([]string, error)) *MembershipService_ListResourcesByPrincipal_Call { + _c.Call.Return(run) + return _c +} + // SetGroupMemberRole provides a mock function with given fields: ctx, groupID, principalID, principalType, roleID func (_m *MembershipService) SetGroupMemberRole(ctx context.Context, groupID string, principalID string, principalType string, roleID string) error { ret := _m.Called(ctx, groupID, principalID, principalType, roleID) diff --git a/core/invitation/mocks/organization_service.go b/core/invitation/mocks/organization_service.go index a7fd17551..9f45828c9 100644 --- a/core/invitation/mocks/organization_service.go +++ b/core/invitation/mocks/organization_service.go @@ -5,8 +5,6 @@ package mocks import ( context "context" - authenticate "github.com/raystack/frontier/core/authenticate" - mock "github.com/stretchr/testify/mock" organization "github.com/raystack/frontier/core/organization" @@ -82,66 +80,6 @@ func (_c *OrganizationService_Get_Call) RunAndReturn(run func(context.Context, s return _c } -// ListByUser provides a mock function with given fields: ctx, p, f -func (_m *OrganizationService) ListByUser(ctx context.Context, p authenticate.Principal, f organization.Filter) ([]organization.Organization, error) { - ret := _m.Called(ctx, p, f) - - if len(ret) == 0 { - panic("no return value specified for ListByUser") - } - - var r0 []organization.Organization - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, organization.Filter) ([]organization.Organization, error)); ok { - return rf(ctx, p, f) - } - if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, organization.Filter) []organization.Organization); ok { - r0 = rf(ctx, p, f) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]organization.Organization) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal, organization.Filter) error); ok { - r1 = rf(ctx, p, f) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// OrganizationService_ListByUser_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListByUser' -type OrganizationService_ListByUser_Call struct { - *mock.Call -} - -// ListByUser is a helper method to define mock.On call -// - ctx context.Context -// - p authenticate.Principal -// - f organization.Filter -func (_e *OrganizationService_Expecter) ListByUser(ctx interface{}, p interface{}, f interface{}) *OrganizationService_ListByUser_Call { - return &OrganizationService_ListByUser_Call{Call: _e.mock.On("ListByUser", ctx, p, f)} -} - -func (_c *OrganizationService_ListByUser_Call) Run(run func(ctx context.Context, p authenticate.Principal, f organization.Filter)) *OrganizationService_ListByUser_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(authenticate.Principal), args[2].(organization.Filter)) - }) - return _c -} - -func (_c *OrganizationService_ListByUser_Call) Return(_a0 []organization.Organization, _a1 error) *OrganizationService_ListByUser_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *OrganizationService_ListByUser_Call) RunAndReturn(run func(context.Context, authenticate.Principal, organization.Filter) ([]organization.Organization, error)) *OrganizationService_ListByUser_Call { - _c.Call.Return(run) - return _c -} - // NewOrganizationService creates a new instance of OrganizationService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewOrganizationService(t interface { diff --git a/core/invitation/service.go b/core/invitation/service.go index 9ff56ae11..144c2644c 100644 --- a/core/invitation/service.go +++ b/core/invitation/service.go @@ -44,12 +44,12 @@ type UserService interface { type OrganizationService interface { Get(ctx context.Context, id string) (organization.Organization, error) - ListByUser(ctx context.Context, p authenticate.Principal, f organization.Filter) ([]organization.Organization, error) } type MembershipService interface { AddOrganizationMember(ctx context.Context, orgID, principalID, principalType, roleID string) error SetGroupMemberRole(ctx context.Context, groupID, principalID, principalType, roleID string) error + ListResourcesByPrincipal(ctx context.Context, principal authenticate.Principal, resourceType string, filter membership.ResourceFilter) ([]string, error) } type GroupService interface { @@ -270,15 +270,15 @@ func (s Service) isUserOrgMember(ctx context.Context, orgID, userID string) (use return userOb, false, err } - orgs, err := s.orgSvc.ListByUser(ctx, authenticate.Principal{ + orgIDs, err := s.membershipSvc.ListResourcesByPrincipal(ctx, authenticate.Principal{ ID: userOb.ID, Type: schema.UserPrincipal, - }, organization.Filter{}) + }, schema.OrganizationNamespace, membership.ResourceFilter{}) if err != nil { return userOb, false, err } - for _, org := range orgs { - if org.ID == orgID { + for _, id := range orgIDs { + if id == orgID { return userOb, true, nil } } diff --git a/core/invitation/service_test.go b/core/invitation/service_test.go index b44625a02..ef9a8a79f 100644 --- a/core/invitation/service_test.go +++ b/core/invitation/service_test.go @@ -11,6 +11,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/raystack/frontier/core/invitation" "github.com/raystack/frontier/core/invitation/mocks" + "github.com/raystack/frontier/core/membership" "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/user" "github.com/raystack/frontier/pkg/errors" @@ -55,14 +56,6 @@ func TestService_Create(t *testing.T) { orgService.EXPECT().Get(mock.Anything, "org-id").Return(organization.Organization{ ID: "org-id", }, nil) - orgService.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ - ID: "user-id", - Type: schema.UserPrincipal, - }, organization.Filter{}).Return([]organization.Organization{ - { - ID: "org-id", - }, - }, nil) userService.EXPECT().GetByID(context.Background(), "test@example.com").Return(user.User{ ID: "user-id", @@ -70,6 +63,10 @@ func TestService_Create(t *testing.T) { }, nil) membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().ListResourcesByPrincipal(mock.Anything, authenticate.Principal{ + ID: "user-id", + Type: schema.UserPrincipal, + }, schema.OrganizationNamespace, membership.ResourceFilter{}).Return([]string{"org-id"}, nil) return invitation.NewService(dialer, repo, orgService, groupService, userService, relationService, prefService, auditRecordRepo, membershipSvc) }, diff --git a/core/membership/service.go b/core/membership/service.go index 76a4c1ba6..dee69d28c 100644 --- a/core/membership/service.go +++ b/core/membership/service.go @@ -375,23 +375,10 @@ func (s *Service) cascadeRemovePrincipal(ctx context.Context, org organization.O // clean up SpiceDB relations for _, g := range orgGroups { if err := s.removeRelations(ctx, g.ID, schema.GroupNamespace, principalID, principalType); err != nil { - s.log.Error("partial failure removing member: group relation cleanup failed, manual cleanup may be needed", - "org_id", orgID, - "group_id", g.ID, - "principal_id", principalID, - "principal_type", principalType, - "error", err, - ) return fmt.Errorf("remove group %s relations: %w", g.ID, err) } } if err := s.removeRelations(ctx, orgID, schema.OrganizationNamespace, principalID, principalType); err != nil { - s.log.Error("partial failure removing member: org relation cleanup failed, manual cleanup may be needed", - "org_id", orgID, - "principal_id", principalID, - "principal_type", principalType, - "error", err, - ) return fmt.Errorf("remove org relations: %w", err) } @@ -1574,6 +1561,13 @@ type ResourceFilter struct { NonInherited bool } +// ListOrgsByPrincipal lets the organization package consume this without +// importing membership — that direction would be a cycle since membership +// already imports organization. +func (s *Service) ListOrgsByPrincipal(ctx context.Context, principal authenticate.Principal) ([]string, error) { + return s.ListResourcesByPrincipal(ctx, principal, schema.OrganizationNamespace, ResourceFilter{}) +} + // ListResourcesByPrincipal returns the resource IDs of the given type on which // the principal has at least one policy. Reads Postgres policies — no SpiceDB. // With a PAT, runs the algorithm twice (user, then PAT-as-principal) and diff --git a/core/organization/filter.go b/core/organization/filter.go index eaf2a2b12..3e328be53 100644 --- a/core/organization/filter.go +++ b/core/organization/filter.go @@ -1,11 +1,14 @@ package organization import ( + "github.com/raystack/frontier/core/authenticate" "github.com/raystack/frontier/pkg/pagination" ) type Filter struct { - UserID string + // Principal restricts results to orgs the principal has a policy on. + // Intersected with IDs when both are set. + Principal *authenticate.Principal IDs []string State State diff --git a/core/organization/mocks/membership_service.go b/core/organization/mocks/membership_service.go index 7a81278b0..738279486 100644 --- a/core/organization/mocks/membership_service.go +++ b/core/organization/mocks/membership_service.go @@ -5,6 +5,8 @@ package mocks import ( context "context" + authenticate "github.com/raystack/frontier/core/authenticate" + mock "github.com/stretchr/testify/mock" ) @@ -71,6 +73,65 @@ func (_c *MembershipService_AddOrganizationMember_Call) RunAndReturn(run func(co return _c } +// ListOrgsByPrincipal provides a mock function with given fields: ctx, principal +func (_m *MembershipService) ListOrgsByPrincipal(ctx context.Context, principal authenticate.Principal) ([]string, error) { + ret := _m.Called(ctx, principal) + + if len(ret) == 0 { + panic("no return value specified for ListOrgsByPrincipal") + } + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal) ([]string, error)); ok { + return rf(ctx, principal) + } + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal) []string); ok { + r0 = rf(ctx, principal) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal) error); ok { + r1 = rf(ctx, principal) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MembershipService_ListOrgsByPrincipal_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListOrgsByPrincipal' +type MembershipService_ListOrgsByPrincipal_Call struct { + *mock.Call +} + +// ListOrgsByPrincipal is a helper method to define mock.On call +// - ctx context.Context +// - principal authenticate.Principal +func (_e *MembershipService_Expecter) ListOrgsByPrincipal(ctx interface{}, principal interface{}) *MembershipService_ListOrgsByPrincipal_Call { + return &MembershipService_ListOrgsByPrincipal_Call{Call: _e.mock.On("ListOrgsByPrincipal", ctx, principal)} +} + +func (_c *MembershipService_ListOrgsByPrincipal_Call) Run(run func(ctx context.Context, principal authenticate.Principal)) *MembershipService_ListOrgsByPrincipal_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(authenticate.Principal)) + }) + return _c +} + +func (_c *MembershipService_ListOrgsByPrincipal_Call) Return(_a0 []string, _a1 error) *MembershipService_ListOrgsByPrincipal_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MembershipService_ListOrgsByPrincipal_Call) RunAndReturn(run func(context.Context, authenticate.Principal) ([]string, error)) *MembershipService_ListOrgsByPrincipal_Call { + _c.Call.Return(run) + return _c +} + // NewMembershipService creates a new instance of MembershipService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMembershipService(t interface { diff --git a/core/organization/mocks/relation_service.go b/core/organization/mocks/relation_service.go index f1ed79397..c00e2d9a2 100644 --- a/core/organization/mocks/relation_service.go +++ b/core/organization/mocks/relation_service.go @@ -127,65 +127,6 @@ func (_c *RelationService_Delete_Call) RunAndReturn(run func(context.Context, re return _c } -// LookupResources provides a mock function with given fields: ctx, rel -func (_m *RelationService) LookupResources(ctx context.Context, rel relation.Relation) ([]string, error) { - ret := _m.Called(ctx, rel) - - if len(ret) == 0 { - panic("no return value specified for LookupResources") - } - - var r0 []string - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, relation.Relation) ([]string, error)); ok { - return rf(ctx, rel) - } - if rf, ok := ret.Get(0).(func(context.Context, relation.Relation) []string); ok { - r0 = rf(ctx, rel) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]string) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, relation.Relation) error); ok { - r1 = rf(ctx, rel) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// RelationService_LookupResources_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LookupResources' -type RelationService_LookupResources_Call struct { - *mock.Call -} - -// LookupResources is a helper method to define mock.On call -// - ctx context.Context -// - rel relation.Relation -func (_e *RelationService_Expecter) LookupResources(ctx interface{}, rel interface{}) *RelationService_LookupResources_Call { - return &RelationService_LookupResources_Call{Call: _e.mock.On("LookupResources", ctx, rel)} -} - -func (_c *RelationService_LookupResources_Call) Run(run func(ctx context.Context, rel relation.Relation)) *RelationService_LookupResources_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(relation.Relation)) - }) - return _c -} - -func (_c *RelationService_LookupResources_Call) Return(_a0 []string, _a1 error) *RelationService_LookupResources_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *RelationService_LookupResources_Call) RunAndReturn(run func(context.Context, relation.Relation) ([]string, error)) *RelationService_LookupResources_Call { - _c.Call.Return(run) - return _c -} - // NewRelationService creates a new instance of RelationService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewRelationService(t interface { diff --git a/core/organization/service.go b/core/organization/service.go index b09be7207..8fbbf165a 100644 --- a/core/organization/service.go +++ b/core/organization/service.go @@ -40,7 +40,6 @@ type Repository interface { type RelationService interface { Create(ctx context.Context, rel relation.Relation) (relation.Relation, error) - LookupResources(ctx context.Context, rel relation.Relation) ([]string, error) Delete(ctx context.Context, rel relation.Relation) error } @@ -75,6 +74,7 @@ type RoleService interface { type MembershipService interface { AddOrganizationMember(ctx context.Context, orgID, principalID, principalType, roleID string) error + ListOrgsByPrincipal(ctx context.Context, principal authenticate.Principal) ([]string, error) } type Service struct { @@ -246,13 +246,25 @@ func (s Service) AttachToPlatform(ctx context.Context, orgID string) error { } func (s Service) List(ctx context.Context, f Filter) ([]Organization, error) { - if f.UserID != "" { - return s.ListByUser(ctx, authenticate.Principal{ - ID: f.UserID, - Type: schema.UserPrincipal, - }, f) + if metrics.ServiceOprLatency != nil { + defer metrics.ServiceOprLatency("organization", "List")() + } + if f.Principal != nil { + if s.membershipService == nil { + return nil, fmt.Errorf("organization: membership service not wired") + } + orgIDs, err := s.membershipService.ListOrgsByPrincipal(ctx, *f.Principal) + if err != nil { + return nil, err + } + if len(f.IDs) > 0 { + orgIDs = utils.Intersection(orgIDs, f.IDs) + } + if len(orgIDs) == 0 { + return []Organization{}, nil + } + f.IDs = orgIDs } - // state gets filtered in db return s.repository.List(ctx, f) } @@ -264,40 +276,6 @@ func (s Service) Update(ctx context.Context, org Organization) (Organization, er return s.repository.UpdateByName(ctx, org) } -func (s Service) ListByUser(ctx context.Context, principal authenticate.Principal, filter Filter) ([]Organization, error) { - if metrics.ServiceOprLatency != nil { - promCollect := metrics.ServiceOprLatency("organization", "ListByUser") - defer promCollect() - } - - subjectID, subjectType := principal.ResolveSubject() - subjectIDs, err := s.relationService.LookupResources(ctx, relation.Relation{ - Object: relation.Object{ - Namespace: schema.OrganizationNamespace, - }, - Subject: relation.Subject{ - ID: subjectID, - Namespace: subjectType, - }, - RelationName: schema.MembershipPermission, - }) - if err != nil { - return nil, err - } - - if principal.PAT != nil { - subjectIDs = utils.Intersection(subjectIDs, []string{principal.PAT.OrgID}) - } - - if len(subjectIDs) == 0 { - // no organizations - return []Organization{}, nil - } - - filter.IDs = subjectIDs - return s.repository.List(ctx, filter) -} - // RemoveUsers removes users from an organization as members // it doesn't remove user access to projects or other resources provided // by policies, don't call directly, use cascade deleter diff --git a/core/organization/service_test.go b/core/organization/service_test.go index 89d484208..f9d0fbd1a 100644 --- a/core/organization/service_test.go +++ b/core/organization/service_test.go @@ -6,12 +6,10 @@ import ( "testing" "github.com/google/uuid" - "github.com/raystack/frontier/core/authenticate" "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/organization/mocks" "github.com/raystack/frontier/core/preference" "github.com/raystack/frontier/core/relation" - pat "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -206,10 +204,10 @@ func TestService_AttachToPlatform(t *testing.T) { }) } -func TestService_ListByUser(t *testing.T) { +func TestService_List(t *testing.T) { ctx := context.Background() - t.Run("should resolve PAT to user and intersect with PAT org", func(t *testing.T) { + newService := func() (*organization.Service, *mocks.Repository) { mockRepo := mocks.NewRepository(t) mockRelationSvc := mocks.NewRelationService(t) mockUserSvc := mocks.NewUserService(t) @@ -217,94 +215,83 @@ func TestService_ListByUser(t *testing.T) { mockPolicySvc := mocks.NewPolicyService(t) mockPrefSvc := mocks.NewPreferencesService(t) mockAuditRecordRepo := mocks.NewAuditRecordRepository(t) - mockRoleSvc := mocks.NewRoleService(t) - svc := organization.NewService(mockRepo, mockRelationSvc, mockUserSvc, mockAuthnSvc, mockPolicySvc, mockPrefSvc, mockAuditRecordRepo, mockRoleSvc) - - // LookupResources should be called with user ID/type, not PAT - mockRelationSvc.On("LookupResources", ctx, relation.Relation{ - Object: relation.Object{Namespace: schema.OrganizationNamespace}, - Subject: relation.Subject{ID: "user-123", Namespace: schema.UserPrincipal}, - RelationName: schema.MembershipPermission, - }).Return([]string{"org-1", "org-2"}, nil).Once() - - // Repo should only be called with the PAT's org (intersection result) - mockRepo.On("List", ctx, organization.Filter{ - IDs: []string{"org-1"}, - }).Return([]organization.Organization{ - {ID: "org-1", Name: "org-one"}, - }, nil).Once() - - result, err := svc.ListByUser(ctx, authenticate.Principal{ - ID: "pat-456", - Type: schema.PATPrincipal, - PAT: &pat.PAT{ID: "pat-456", UserID: "user-123", OrgID: "org-1"}, - }, organization.Filter{}) + svc := organization.NewService(mockRepo, mockRelationSvc, mockUserSvc, mockAuthnSvc, + mockPolicySvc, mockPrefSvc, mockAuditRecordRepo, mockRoleSvc) + return svc, mockRepo + } + t.Run("passes empty filter to repository unchanged", func(t *testing.T) { + svc, mockRepo := newService() + mockRepo.On("List", ctx, organization.Filter{}). + Return([]organization.Organization{ + {ID: "org-1", Name: "org-one"}, + {ID: "org-2", Name: "org-two"}, + }, nil).Once() + + got, err := svc.List(ctx, organization.Filter{}) assert.NoError(t, err) - assert.Len(t, result, 1) - assert.Equal(t, "org-1", result[0].ID) + assert.Len(t, got, 2) }) - t.Run("should return empty when PAT org not in user memberships", func(t *testing.T) { - mockRepo := mocks.NewRepository(t) - mockRelationSvc := mocks.NewRelationService(t) - mockUserSvc := mocks.NewUserService(t) - mockAuthnSvc := mocks.NewAuthnService(t) - mockPolicySvc := mocks.NewPolicyService(t) - mockPrefSvc := mocks.NewPreferencesService(t) - mockAuditRecordRepo := mocks.NewAuditRecordRepository(t) + t.Run("forwards IDs filter to repository", func(t *testing.T) { + svc, mockRepo := newService() + ids := []string{"org-1", "org-2"} + mockRepo.On("List", ctx, organization.Filter{IDs: ids}). + Return([]organization.Organization{ + {ID: "org-1", Name: "org-one"}, + {ID: "org-2", Name: "org-two"}, + }, nil).Once() - mockRoleSvc := mocks.NewRoleService(t) - svc := organization.NewService(mockRepo, mockRelationSvc, mockUserSvc, mockAuthnSvc, mockPolicySvc, mockPrefSvc, mockAuditRecordRepo, mockRoleSvc) + got, err := svc.List(ctx, organization.Filter{IDs: ids}) + assert.NoError(t, err) + assert.Len(t, got, 2) + }) - mockRelationSvc.On("LookupResources", ctx, relation.Relation{ - Object: relation.Object{Namespace: schema.OrganizationNamespace}, - Subject: relation.Subject{ID: "user-123", Namespace: schema.UserPrincipal}, - RelationName: schema.MembershipPermission, - }).Return([]string{"org-1", "org-2"}, nil).Once() + t.Run("forwards state filter to repository", func(t *testing.T) { + svc, mockRepo := newService() + mockRepo.On("List", ctx, organization.Filter{State: organization.Disabled}). + Return([]organization.Organization{ + {ID: "org-3", Name: "org-three", State: organization.Disabled}, + }, nil).Once() - result, err := svc.ListByUser(ctx, authenticate.Principal{ - ID: "pat-456", - Type: schema.PATPrincipal, - PAT: &pat.PAT{ID: "pat-456", UserID: "user-123", OrgID: "org-999"}, - }, organization.Filter{}) + got, err := svc.List(ctx, organization.Filter{State: organization.Disabled}) + assert.NoError(t, err) + assert.Len(t, got, 1) + assert.Equal(t, organization.Disabled, got[0].State) + }) + t.Run("forwards combined IDs and state filter to repository", func(t *testing.T) { + svc, mockRepo := newService() + f := organization.Filter{IDs: []string{"org-1"}, State: organization.Enabled} + mockRepo.On("List", ctx, f). + Return([]organization.Organization{ + {ID: "org-1", Name: "org-one", State: organization.Enabled}, + }, nil).Once() + + got, err := svc.List(ctx, f) assert.NoError(t, err) - assert.Empty(t, result) + assert.Len(t, got, 1) }) - t.Run("should pass through for regular user principal", func(t *testing.T) { - mockRepo := mocks.NewRepository(t) - mockRelationSvc := mocks.NewRelationService(t) - mockUserSvc := mocks.NewUserService(t) - mockAuthnSvc := mocks.NewAuthnService(t) - mockPolicySvc := mocks.NewPolicyService(t) - mockPrefSvc := mocks.NewPreferencesService(t) - mockAuditRecordRepo := mocks.NewAuditRecordRepository(t) + t.Run("propagates repository errors unchanged", func(t *testing.T) { + svc, mockRepo := newService() + repoErr := errors.New("db down") + mockRepo.On("List", ctx, organization.Filter{}). + Return(nil, repoErr).Once() - mockRoleSvc := mocks.NewRoleService(t) - svc := organization.NewService(mockRepo, mockRelationSvc, mockUserSvc, mockAuthnSvc, mockPolicySvc, mockPrefSvc, mockAuditRecordRepo, mockRoleSvc) - - mockRelationSvc.On("LookupResources", ctx, relation.Relation{ - Object: relation.Object{Namespace: schema.OrganizationNamespace}, - Subject: relation.Subject{ID: "user-123", Namespace: schema.UserPrincipal}, - RelationName: schema.MembershipPermission, - }).Return([]string{"org-1", "org-2"}, nil).Once() - - mockRepo.On("List", ctx, organization.Filter{ - IDs: []string{"org-1", "org-2"}, - }).Return([]organization.Organization{ - {ID: "org-1", Name: "org-one"}, - {ID: "org-2", Name: "org-two"}, - }, nil).Once() - - result, err := svc.ListByUser(ctx, authenticate.Principal{ - ID: "user-123", - Type: schema.UserPrincipal, - }, organization.Filter{}) + got, err := svc.List(ctx, organization.Filter{}) + assert.ErrorIs(t, err, repoErr) + assert.Nil(t, got) + }) + + t.Run("returns empty slice when repository returns no rows", func(t *testing.T) { + svc, mockRepo := newService() + mockRepo.On("List", ctx, organization.Filter{IDs: []string{"org-nope"}}). + Return([]organization.Organization{}, nil).Once() + got, err := svc.List(ctx, organization.Filter{IDs: []string{"org-nope"}}) assert.NoError(t, err) - assert.Len(t, result, 2) + assert.Empty(t, got) }) } diff --git a/internal/api/v1beta1connect/authenticate.go b/internal/api/v1beta1connect/authenticate.go index 83c83058d..f55761d65 100644 --- a/internal/api/v1beta1connect/authenticate.go +++ b/internal/api/v1beta1connect/authenticate.go @@ -224,17 +224,18 @@ func (h *ConnectHandler) getAccessToken(ctx context.Context, principal authentic customClaims := map[string]string{} if h.authConfig.Token.Claims.AddOrgIDsClaim { - // get orgs a user belongs to - orgs, err := h.orgService.ListByUser(ctx, principal, organization.Filter{}) + orgs, err := h.orgService.List(ctx, organization.Filter{ + Principal: &principal, + State: organization.Enabled, + }) if err != nil { return nil, err } - - var orgIds []string + orgIDs := make([]string, 0, len(orgs)) for _, o := range orgs { - orgIds = append(orgIds, o.ID) + orgIDs = append(orgIDs, o.ID) } - customClaims[token.OrgIDsClaimKey] = strings.Join(orgIds, ",") + customClaims[token.OrgIDsClaimKey] = strings.Join(orgIDs, ",") } // add session ID as claims for upstream diff --git a/internal/api/v1beta1connect/interfaces.go b/internal/api/v1beta1connect/interfaces.go index 077569cb4..90630da99 100644 --- a/internal/api/v1beta1connect/interfaces.go +++ b/internal/api/v1beta1connect/interfaces.go @@ -124,7 +124,6 @@ type OrganizationService interface { AdminCreate(ctx context.Context, org organization.Organization, ownerEmail string) (organization.Organization, error) List(ctx context.Context, f organization.Filter) ([]organization.Organization, error) Update(ctx context.Context, toUpdate organization.Organization) (organization.Organization, error) - ListByUser(ctx context.Context, principal authenticate.Principal, flt organization.Filter) ([]organization.Organization, error) Enable(ctx context.Context, id string) error Disable(ctx context.Context, id string) error } diff --git a/internal/api/v1beta1connect/mocks/organization_service.go b/internal/api/v1beta1connect/mocks/organization_service.go index fac005cc6..24352891f 100644 --- a/internal/api/v1beta1connect/mocks/organization_service.go +++ b/internal/api/v1beta1connect/mocks/organization_service.go @@ -5,11 +5,8 @@ package mocks import ( context "context" - authenticate "github.com/raystack/frontier/core/authenticate" - - mock "github.com/stretchr/testify/mock" - organization "github.com/raystack/frontier/core/organization" + mock "github.com/stretchr/testify/mock" ) // OrganizationService is an autogenerated mock type for the OrganizationService type @@ -407,66 +404,6 @@ func (_c *OrganizationService_List_Call) RunAndReturn(run func(context.Context, return _c } -// ListByUser provides a mock function with given fields: ctx, principal, flt -func (_m *OrganizationService) ListByUser(ctx context.Context, principal authenticate.Principal, flt organization.Filter) ([]organization.Organization, error) { - ret := _m.Called(ctx, principal, flt) - - if len(ret) == 0 { - panic("no return value specified for ListByUser") - } - - var r0 []organization.Organization - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, organization.Filter) ([]organization.Organization, error)); ok { - return rf(ctx, principal, flt) - } - if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, organization.Filter) []organization.Organization); ok { - r0 = rf(ctx, principal, flt) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]organization.Organization) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal, organization.Filter) error); ok { - r1 = rf(ctx, principal, flt) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// OrganizationService_ListByUser_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListByUser' -type OrganizationService_ListByUser_Call struct { - *mock.Call -} - -// ListByUser is a helper method to define mock.On call -// - ctx context.Context -// - principal authenticate.Principal -// - flt organization.Filter -func (_e *OrganizationService_Expecter) ListByUser(ctx interface{}, principal interface{}, flt interface{}) *OrganizationService_ListByUser_Call { - return &OrganizationService_ListByUser_Call{Call: _e.mock.On("ListByUser", ctx, principal, flt)} -} - -func (_c *OrganizationService_ListByUser_Call) Run(run func(ctx context.Context, principal authenticate.Principal, flt organization.Filter)) *OrganizationService_ListByUser_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(authenticate.Principal), args[2].(organization.Filter)) - }) - return _c -} - -func (_c *OrganizationService_ListByUser_Call) Return(_a0 []organization.Organization, _a1 error) *OrganizationService_ListByUser_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *OrganizationService_ListByUser_Call) RunAndReturn(run func(context.Context, authenticate.Principal, organization.Filter) ([]organization.Organization, error)) *OrganizationService_ListByUser_Call { - _c.Call.Return(run) - return _c -} - // Update provides a mock function with given fields: ctx, toUpdate func (_m *OrganizationService) Update(ctx context.Context, toUpdate organization.Organization) (organization.Organization, error) { ret := _m.Called(ctx, toUpdate) diff --git a/internal/api/v1beta1connect/organization.go b/internal/api/v1beta1connect/organization.go index 3907baa29..abeb3c91b 100644 --- a/internal/api/v1beta1connect/organization.go +++ b/internal/api/v1beta1connect/organization.go @@ -8,6 +8,7 @@ import ( "log/slog" "github.com/raystack/frontier/core/audit" + "github.com/raystack/frontier/core/authenticate" "github.com/raystack/frontier/core/membership" "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/project" @@ -57,70 +58,54 @@ func (h *ConnectHandler) GetOrganization(ctx context.Context, request *connect.R } func (h *ConnectHandler) ListOrganizations(ctx context.Context, request *connect.Request[frontierv1beta1.ListOrganizationsRequest]) (*connect.Response[frontierv1beta1.ListOrganizationsResponse], error) { - errorLogger := NewErrorLogger() - - var orgs []*frontierv1beta1.Organization - paginate := pagination.NewPagination(request.Msg.GetPageNum(), request.Msg.GetPageSize()) - - orgList, err := h.orgService.List(ctx, organization.Filter{ - State: organization.State(request.Msg.GetState()), - UserID: request.Msg.GetUserId(), - Pagination: paginate, - }) + orgs, _, err := h.searchOrgs(ctx, request, request.Msg.GetUserId(), request.Msg.GetState(), request.Msg.GetPageNum(), request.Msg.GetPageSize(), "ListOrganizations") if err != nil { - errorLogger.LogServiceError(ctx, request, "ListOrganizations.List", err, - "state", request.Msg.GetState(), - "user_id", request.Msg.GetUserId()) - return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) - } - - for _, v := range orgList { - orgPB, err := transformOrgToPB(v) - if err != nil { - errorLogger.LogTransformError(ctx, request, "ListOrganizations", v.ID, err) - return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) - } - - orgs = append(orgs, orgPB) + return nil, err } - return connect.NewResponse(&frontierv1beta1.ListOrganizationsResponse{ Organizations: orgs, }), nil } func (h *ConnectHandler) ListAllOrganizations(ctx context.Context, request *connect.Request[frontierv1beta1.ListAllOrganizationsRequest]) (*connect.Response[frontierv1beta1.ListAllOrganizationsResponse], error) { - errorLogger := NewErrorLogger() - - var orgs []*frontierv1beta1.Organization - paginate := pagination.NewPagination(request.Msg.GetPageNum(), request.Msg.GetPageSize()) + orgs, count, err := h.searchOrgs(ctx, request, request.Msg.GetUserId(), request.Msg.GetState(), request.Msg.GetPageNum(), request.Msg.GetPageSize(), "ListAllOrganizations") + if err != nil { + return nil, err + } + return connect.NewResponse(&frontierv1beta1.ListAllOrganizationsResponse{ + Organizations: orgs, + Count: count, + }), nil +} - orgList, err := h.orgService.List(ctx, organization.Filter{ - State: organization.State(request.Msg.GetState()), - UserID: request.Msg.GetUserId(), +func (h *ConnectHandler) searchOrgs(ctx context.Context, req connect.AnyRequest, userID, stateStr string, pageNum, pageSize int32, rpcName string) ([]*frontierv1beta1.Organization, int32, error) { + errorLogger := NewErrorLogger() + paginate := pagination.NewPagination(pageNum, pageSize) + filter := organization.Filter{ + State: organization.State(stateStr), Pagination: paginate, - }) + } + if userID != "" { + filter.Principal = &authenticate.Principal{ID: userID, Type: schema.UserPrincipal} + } + + orgList, err := h.orgService.List(ctx, filter) if err != nil { - errorLogger.LogServiceError(ctx, request, "ListAllOrganizations.List", err, - "state", request.Msg.GetState(), - "user_id", request.Msg.GetUserId()) - return nil, err + errorLogger.LogServiceError(ctx, req, rpcName+".List", err, + "state", stateStr, "user_id", userID) + return nil, 0, connect.NewError(connect.CodeInternal, ErrInternalServerError) } + orgs := make([]*frontierv1beta1.Organization, 0, len(orgList)) for _, v := range orgList { orgPB, err := transformOrgToPB(v) if err != nil { - errorLogger.LogTransformError(ctx, request, "ListAllOrganizations", v.ID, err) - return nil, err + errorLogger.LogTransformError(ctx, req, rpcName, v.ID, err) + return nil, 0, connect.NewError(connect.CodeInternal, ErrInternalServerError) } - orgs = append(orgs, orgPB) } - - return connect.NewResponse(&frontierv1beta1.ListAllOrganizationsResponse{ - Organizations: orgs, - Count: paginate.Count, - }), nil + return orgs, paginate.Count, nil } func (h *ConnectHandler) CreateOrganization(ctx context.Context, request *connect.Request[frontierv1beta1.CreateOrganizationRequest]) (*connect.Response[frontierv1beta1.CreateOrganizationResponse], error) { diff --git a/internal/api/v1beta1connect/user.go b/internal/api/v1beta1connect/user.go index af6689b71..e779a2f11 100644 --- a/internal/api/v1beta1connect/user.go +++ b/internal/api/v1beta1connect/user.go @@ -728,15 +728,10 @@ func (h *ConnectHandler) ListOrganizationsByUser(ctx context.Context, request *c errorLogger := NewErrorLogger() userID := request.Msg.GetId() - orgFilter := organization.Filter{} - if request.Msg.GetState() != "" { - orgFilter.State = organization.State(request.Msg.GetState()) - } - - orgList, err := h.orgService.ListByUser(ctx, authenticate.Principal{ - ID: userID, - Type: schema.UserPrincipal, - }, orgFilter) + orgList, err := h.orgService.List(ctx, organization.Filter{ + Principal: &authenticate.Principal{ID: userID, Type: schema.UserPrincipal}, + State: organization.State(request.Msg.GetState()), + }) if err != nil { errorLogger.LogUnexpectedError(ctx, request, "ListOrganizationsByUser", err, "user_id", userID) @@ -808,12 +803,10 @@ func (h *ConnectHandler) ListOrganizationsByCurrentUser(ctx context.Context, req return nil, err } - orgFilter := organization.Filter{} - if request.Msg.GetState() != "" { - orgFilter.State = organization.State(request.Msg.GetState()) - } - - orgList, err := h.orgService.ListByUser(ctx, principal, orgFilter) + orgList, err := h.orgService.List(ctx, organization.Filter{ + Principal: &principal, + State: organization.State(request.Msg.GetState()), + }) if err != nil { errorLogger.LogUnexpectedError(ctx, request, "ListOrganizationsByCurrentUser", err, "principal_id", principal.ID, diff --git a/internal/api/v1beta1connect/user_test.go b/internal/api/v1beta1connect/user_test.go index f69671229..4196f3f77 100644 --- a/internal/api/v1beta1connect/user_test.go +++ b/internal/api/v1beta1connect/user_test.go @@ -1191,6 +1191,9 @@ func TestConnectHandler_ListCurrentUserGroups(t *testing.T) { func TestConnectHandler_ListOrganizationsByUser(t *testing.T) { userID := uuid.New().String() + userPrincipal := authenticate.Principal{ID: userID, Type: schema.UserPrincipal} + principalFilter := organization.Filter{Principal: &userPrincipal} + tests := []struct { title string setup func(*mocks.OrganizationService, *mocks.UserService, *mocks.DomainService) @@ -1201,18 +1204,16 @@ func TestConnectHandler_ListOrganizationsByUser(t *testing.T) { { title: "should list user organizations successfully", setup: func(os *mocks.OrganizationService, us *mocks.UserService, ds *mocks.DomainService) { - os.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ - ID: userID, - Type: schema.UserPrincipal, - }, organization.Filter{}).Return([]organization.Organization{ - { - ID: "org-1", - Name: "test-org-1", - Title: "Test Organization 1", - State: organization.Enabled, - Metadata: metadata.Metadata{}, - }, - }, nil) + os.EXPECT().List(mock.Anything, principalFilter). + Return([]organization.Organization{ + { + ID: "org-1", + Name: "test-org-1", + Title: "Test Organization 1", + State: organization.Enabled, + Metadata: metadata.Metadata{}, + }, + }, nil) us.EXPECT().GetByID(mock.Anything, userID).Return(user.User{ ID: userID, @@ -1254,10 +1255,8 @@ func TestConnectHandler_ListOrganizationsByUser(t *testing.T) { { title: "should return empty list when user has no organizations", setup: func(os *mocks.OrganizationService, us *mocks.UserService, ds *mocks.DomainService) { - os.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ - ID: userID, - Type: schema.UserPrincipal, - }, organization.Filter{}).Return([]organization.Organization{}, nil) + os.EXPECT().List(mock.Anything, principalFilter). + Return(nil, nil) us.EXPECT().GetByID(mock.Anything, userID).Return(user.User{ ID: userID, @@ -1279,10 +1278,8 @@ func TestConnectHandler_ListOrganizationsByUser(t *testing.T) { { title: "should return not found error for invalid user ID", setup: func(os *mocks.OrganizationService, us *mocks.UserService, ds *mocks.DomainService) { - os.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ - ID: userID, - Type: schema.UserPrincipal, - }, organization.Filter{}).Return([]organization.Organization{}, nil) + os.EXPECT().List(mock.Anything, principalFilter). + Return(nil, nil) us.EXPECT().GetByID(mock.Anything, userID).Return(user.User{}, user.ErrNotExist) }, @@ -1295,10 +1292,8 @@ func TestConnectHandler_ListOrganizationsByUser(t *testing.T) { { title: "should return internal error for service failure", setup: func(os *mocks.OrganizationService, us *mocks.UserService, ds *mocks.DomainService) { - os.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ - ID: userID, - Type: schema.UserPrincipal, - }, organization.Filter{}).Return(nil, errors.New("database error")) + os.EXPECT().List(mock.Anything, principalFilter). + Return(nil, errors.New("database error")) }, req: &frontierv1beta1.ListOrganizationsByUserRequest{ Id: userID, @@ -1359,6 +1354,17 @@ func TestConnectHandler_ListOrganizationsByUser(t *testing.T) { } func TestConnectHandler_ListOrganizationsByCurrentUser(t *testing.T) { + userPrincipal := authenticate.Principal{ + ID: "user-1", + Type: "app/user", + User: &user.User{ID: "user-1", Email: "test@example.com"}, + } + suPrincipal := authenticate.Principal{ + ID: "serviceuser-1", + Type: schema.ServiceUserPrincipal, + ServiceUser: &serviceuser.ServiceUser{ID: "serviceuser-1", OrgID: "org-1"}, + } + tests := []struct { title string setup func(*mocks.OrganizationService, *mocks.AuthnService, *mocks.DomainService) @@ -1369,22 +1375,18 @@ func TestConnectHandler_ListOrganizationsByCurrentUser(t *testing.T) { { title: "should list current user organizations successfully", setup: func(os *mocks.OrganizationService, as *mocks.AuthnService, ds *mocks.DomainService) { - mockPrincipal := authenticate.Principal{ - ID: "user-1", - Type: "app/user", - User: &user.User{ID: "user-1", Email: "test@example.com"}, - } - as.EXPECT().GetPrincipal(mock.Anything).Return(mockPrincipal, nil) - - os.EXPECT().ListByUser(mock.Anything, mockPrincipal, organization.Filter{}).Return([]organization.Organization{ - { - ID: "org-1", - Name: "test-org-1", - Title: "Test Organization 1", - State: organization.Enabled, - Metadata: metadata.Metadata{}, - }, - }, nil) + as.EXPECT().GetPrincipal(mock.Anything).Return(userPrincipal, nil) + + os.EXPECT().List(mock.Anything, organization.Filter{Principal: &userPrincipal}). + Return([]organization.Organization{ + { + ID: "org-1", + Name: "test-org-1", + Title: "Test Organization 1", + State: organization.Enabled, + Metadata: metadata.Metadata{}, + }, + }, nil) ds.EXPECT().ListJoinableOrgsByDomain(mock.Anything, "test@example.com").Return([]string{"org-2"}, nil) @@ -1418,14 +1420,10 @@ func TestConnectHandler_ListOrganizationsByCurrentUser(t *testing.T) { { title: "should return empty list when current user has no organizations", setup: func(os *mocks.OrganizationService, as *mocks.AuthnService, ds *mocks.DomainService) { - mockPrincipal := authenticate.Principal{ - ID: "user-1", - Type: "app/user", - User: &user.User{ID: "user-1", Email: "test@example.com"}, - } - as.EXPECT().GetPrincipal(mock.Anything).Return(mockPrincipal, nil) + as.EXPECT().GetPrincipal(mock.Anything).Return(userPrincipal, nil) - os.EXPECT().ListByUser(mock.Anything, mockPrincipal, organization.Filter{}).Return([]organization.Organization{}, nil) + os.EXPECT().List(mock.Anything, organization.Filter{Principal: &userPrincipal}). + Return(nil, nil) ds.EXPECT().ListJoinableOrgsByDomain(mock.Anything, "test@example.com").Return([]string{}, nil) }, @@ -1439,23 +1437,18 @@ func TestConnectHandler_ListOrganizationsByCurrentUser(t *testing.T) { { title: "should handle service user without accessing user email", setup: func(os *mocks.OrganizationService, as *mocks.AuthnService, ds *mocks.DomainService) { - mockPrincipal := authenticate.Principal{ - ID: "serviceuser-1", - Type: schema.ServiceUserPrincipal, - ServiceUser: &serviceuser.ServiceUser{ID: "serviceuser-1", OrgID: "org-1"}, - User: nil, // Service users don't have a User object - } - as.EXPECT().GetPrincipal(mock.Anything).Return(mockPrincipal, nil) - - os.EXPECT().ListByUser(mock.Anything, mockPrincipal, organization.Filter{}).Return([]organization.Organization{ - { - ID: "org-1", - Name: "service-org", - Title: "Service Organization", - State: organization.Enabled, - Metadata: metadata.Metadata{}, - }, - }, nil) + as.EXPECT().GetPrincipal(mock.Anything).Return(suPrincipal, nil) + + os.EXPECT().List(mock.Anything, organization.Filter{Principal: &suPrincipal}). + Return([]organization.Organization{ + { + ID: "org-1", + Name: "service-org", + Title: "Service Organization", + State: organization.Enabled, + Metadata: metadata.Metadata{}, + }, + }, nil) // No domain service call expected since service users can't join by domain }, req: &frontierv1beta1.ListOrganizationsByCurrentUserRequest{}, @@ -1481,16 +1474,12 @@ func TestConnectHandler_ListOrganizationsByCurrentUser(t *testing.T) { err: connect.CodeUnauthenticated, }, { - title: "should return internal error for organization service failure", + title: "should return internal error for service failure", setup: func(os *mocks.OrganizationService, as *mocks.AuthnService, ds *mocks.DomainService) { - mockPrincipal := authenticate.Principal{ - ID: "user-1", - Type: "app/user", - User: &user.User{ID: "user-1", Email: "test@example.com"}, - } - as.EXPECT().GetPrincipal(mock.Anything).Return(mockPrincipal, nil) + as.EXPECT().GetPrincipal(mock.Anything).Return(userPrincipal, nil) - os.EXPECT().ListByUser(mock.Anything, mockPrincipal, organization.Filter{}).Return(nil, errors.New("database error")) + os.EXPECT().List(mock.Anything, organization.Filter{Principal: &userPrincipal}). + Return(nil, errors.New("database error")) }, req: &frontierv1beta1.ListOrganizationsByCurrentUserRequest{}, want: nil,