diff --git a/database/mock/store.go b/database/mock/store.go index 4ee5419eb5..67c7ab7e7d 100644 --- a/database/mock/store.go +++ b/database/mock/store.go @@ -462,17 +462,17 @@ func (mr *MockStoreMockRecorder) CreateUser(ctx, identitySubject any) *gomock.Ca } // DeleteAllPropertiesForEntity mocks base method. -func (m *MockStore) DeleteAllPropertiesForEntity(ctx context.Context, entityID uuid.UUID) error { +func (m *MockStore) DeleteAllPropertiesForEntity(ctx context.Context, arg db.DeleteAllPropertiesForEntityParams) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteAllPropertiesForEntity", ctx, entityID) + ret := m.ctrl.Call(m, "DeleteAllPropertiesForEntity", ctx, arg) ret0, _ := ret[0].(error) return ret0 } // DeleteAllPropertiesForEntity indicates an expected call of DeleteAllPropertiesForEntity. -func (mr *MockStoreMockRecorder) DeleteAllPropertiesForEntity(ctx, entityID any) *gomock.Call { +func (mr *MockStoreMockRecorder) DeleteAllPropertiesForEntity(ctx, arg any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllPropertiesForEntity", reflect.TypeOf((*MockStore)(nil).DeleteAllPropertiesForEntity), ctx, entityID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllPropertiesForEntity", reflect.TypeOf((*MockStore)(nil).DeleteAllPropertiesForEntity), ctx, arg) } // DeleteDataSource mocks base method. @@ -837,18 +837,18 @@ func (mr *MockStoreMockRecorder) FindProviders(ctx, arg any) *gomock.Call { } // FlushCache mocks base method. -func (m *MockStore) FlushCache(ctx context.Context, entityInstanceID uuid.UUID) (db.FlushCache, error) { +func (m *MockStore) FlushCache(ctx context.Context, arg db.FlushCacheParams) (db.FlushCache, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "FlushCache", ctx, entityInstanceID) + ret := m.ctrl.Call(m, "FlushCache", ctx, arg) ret0, _ := ret[0].(db.FlushCache) ret1, _ := ret[1].(error) return ret0, ret1 } // FlushCache indicates an expected call of FlushCache. -func (mr *MockStoreMockRecorder) FlushCache(ctx, entityInstanceID any) *gomock.Call { +func (mr *MockStoreMockRecorder) FlushCache(ctx, arg any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FlushCache", reflect.TypeOf((*MockStore)(nil).FlushCache), ctx, entityInstanceID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FlushCache", reflect.TypeOf((*MockStore)(nil).FlushCache), ctx, arg) } // GetAccessTokenByEnrollmentNonce mocks base method. @@ -912,18 +912,18 @@ func (mr *MockStoreMockRecorder) GetAccessTokenSinceDate(ctx, arg any) *gomock.C } // GetAllPropertiesForEntity mocks base method. -func (m *MockStore) GetAllPropertiesForEntity(ctx context.Context, entityID uuid.UUID) ([]db.Property, error) { +func (m *MockStore) GetAllPropertiesForEntity(ctx context.Context, arg db.GetAllPropertiesForEntityParams) ([]db.Property, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAllPropertiesForEntity", ctx, entityID) + ret := m.ctrl.Call(m, "GetAllPropertiesForEntity", ctx, arg) ret0, _ := ret[0].([]db.Property) ret1, _ := ret[1].(error) return ret0, ret1 } // GetAllPropertiesForEntity indicates an expected call of GetAllPropertiesForEntity. -func (mr *MockStoreMockRecorder) GetAllPropertiesForEntity(ctx, entityID any) *gomock.Call { +func (mr *MockStoreMockRecorder) GetAllPropertiesForEntity(ctx, arg any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllPropertiesForEntity", reflect.TypeOf((*MockStore)(nil).GetAllPropertiesForEntity), ctx, entityID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllPropertiesForEntity", reflect.TypeOf((*MockStore)(nil).GetAllPropertiesForEntity), ctx, arg) } // GetBundle mocks base method. @@ -1002,18 +1002,18 @@ func (mr *MockStoreMockRecorder) GetEntitiesByProjectHierarchy(ctx, projects any } // GetEntitiesByProvider mocks base method. -func (m *MockStore) GetEntitiesByProvider(ctx context.Context, providerID uuid.UUID) ([]db.EntityInstance, error) { +func (m *MockStore) GetEntitiesByProvider(ctx context.Context, arg db.GetEntitiesByProviderParams) ([]db.EntityInstance, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetEntitiesByProvider", ctx, providerID) + ret := m.ctrl.Call(m, "GetEntitiesByProvider", ctx, arg) ret0, _ := ret[0].([]db.EntityInstance) ret1, _ := ret[1].(error) return ret0, ret1 } // GetEntitiesByProvider indicates an expected call of GetEntitiesByProvider. -func (mr *MockStoreMockRecorder) GetEntitiesByProvider(ctx, providerID any) *gomock.Call { +func (mr *MockStoreMockRecorder) GetEntitiesByProvider(ctx, arg any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEntitiesByProvider", reflect.TypeOf((*MockStore)(nil).GetEntitiesByProvider), ctx, providerID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEntitiesByProvider", reflect.TypeOf((*MockStore)(nil).GetEntitiesByProvider), ctx, arg) } // GetEntitiesByType mocks base method. @@ -1047,18 +1047,18 @@ func (mr *MockStoreMockRecorder) GetEntitlementFeaturesByProjectID(ctx, projectI } // GetEntityByID mocks base method. -func (m *MockStore) GetEntityByID(ctx context.Context, id uuid.UUID) (db.EntityInstance, error) { +func (m *MockStore) GetEntityByID(ctx context.Context, arg db.GetEntityByIDParams) (db.EntityInstance, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetEntityByID", ctx, id) + ret := m.ctrl.Call(m, "GetEntityByID", ctx, arg) ret0, _ := ret[0].(db.EntityInstance) ret1, _ := ret[1].(error) return ret0, ret1 } // GetEntityByID indicates an expected call of GetEntityByID. -func (mr *MockStoreMockRecorder) GetEntityByID(ctx, id any) *gomock.Call { +func (mr *MockStoreMockRecorder) GetEntityByID(ctx, arg any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEntityByID", reflect.TypeOf((*MockStore)(nil).GetEntityByID), ctx, id) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEntityByID", reflect.TypeOf((*MockStore)(nil).GetEntityByID), ctx, arg) } // GetEntityByName mocks base method. @@ -1436,6 +1436,21 @@ func (mr *MockStoreMockRecorder) GetProjectIDBySessionState(ctx, sessionState an return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProjectIDBySessionState", reflect.TypeOf((*MockStore)(nil).GetProjectIDBySessionState), ctx, sessionState) } +// GetPropertiesForEntities mocks base method. +func (m *MockStore) GetPropertiesForEntities(ctx context.Context, arg db.GetPropertiesForEntitiesParams) ([]db.Property, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPropertiesForEntities", ctx, arg) + ret0, _ := ret[0].([]db.Property) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPropertiesForEntities indicates an expected call of GetPropertiesForEntities. +func (mr *MockStoreMockRecorder) GetPropertiesForEntities(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPropertiesForEntities", reflect.TypeOf((*MockStore)(nil).GetPropertiesForEntities), ctx, arg) +} + // GetProperty mocks base method. func (m *MockStore) GetProperty(ctx context.Context, arg db.GetPropertyParams) (db.Property, error) { m.ctrl.T.Helper() @@ -1944,10 +1959,10 @@ func (mr *MockStoreMockRecorder) ListEvaluationHistoryStaleRecords(ctx, arg any) } // ListFlushCache mocks base method. -func (m *MockStore) ListFlushCache(ctx context.Context) ([]db.FlushCache, error) { +func (m *MockStore) ListFlushCache(ctx context.Context) ([]db.ListFlushCacheRow, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ListFlushCache", ctx) - ret0, _ := ret[0].([]db.FlushCache) + ret0, _ := ret[0].([]db.ListFlushCacheRow) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/database/query/entities.sql b/database/query/entities.sql index 771a008251..24865c15c5 100644 --- a/database/query/entities.sql +++ b/database/query/entities.sql @@ -1,5 +1,4 @@ -- CreateEntity adds an entry to the entity_instances table so it can be tracked by Minder. - -- name: CreateEntity :one INSERT INTO entity_instances ( entity_type, @@ -7,11 +6,16 @@ INSERT INTO entity_instances ( project_id, provider_id, originated_from -) VALUES ($1, $2, sqlc.arg(project_id), sqlc.arg(provider_id), sqlc.narg(originated_from)) +) VALUES ( + sqlc.arg(entity_type), + sqlc.arg(name), + sqlc.arg(project_id), + sqlc.arg(provider_id), + sqlc.narg(originated_from) +) RETURNING *; -- CreateEntityWithID adds an entry to the entities table with a specific ID so it can be tracked by Minder. - -- name: CreateEntityWithID :one INSERT INTO entity_instances ( id, @@ -20,12 +24,17 @@ INSERT INTO entity_instances ( project_id, provider_id, originated_from -) VALUES ($1, $2, $3, sqlc.arg(project_id), sqlc.arg(provider_id), sqlc.narg(originated_from)) +) VALUES ( + sqlc.arg(id), + sqlc.arg(entity_type), + sqlc.arg(name), + sqlc.arg(project_id), + sqlc.arg(provider_id), + sqlc.narg(originated_from) +) RETURNING *; - -- CreateOrEnsureEntityByID adds an entry to the entity_instances table if it does not exist, or returns the existing entry. - -- name: CreateOrEnsureEntityByID :one INSERT INTO entity_instances ( id, @@ -34,120 +43,165 @@ INSERT INTO entity_instances ( project_id, provider_id, originated_from -) VALUES ($1, $2, $3, sqlc.arg(project_id), sqlc.arg(provider_id), sqlc.narg(originated_from)) +) VALUES ( + sqlc.arg(id), + sqlc.arg(entity_type), + sqlc.arg(name), + sqlc.arg(project_id), + sqlc.arg(provider_id), + sqlc.narg(originated_from) +) ON CONFLICT (id) DO UPDATE SET id = entity_instances.id -- This is a "noop" update to ensure the RETURNING clause works RETURNING *; --- DeleteEntity removes an entity from the entity_instances table for a project. +-- DeleteEntity removes an entity from the entity_instances table securely. -- name: DeleteEntity :exec DELETE FROM entity_instances -WHERE id = $1 AND project_id = $2; +WHERE id = sqlc.arg(id) + AND project_id = sqlc.arg(project_id) + AND provider_id = sqlc.arg(provider_id); -- GetEntityByID retrieves an entity by its ID for a project or hierarchy of projects. -- name: GetEntityByID :one SELECT * FROM entity_instances -WHERE entity_instances.id = $1 +WHERE id = sqlc.arg(id) + AND project_id = sqlc.arg(project_id) + AND provider_id = sqlc.arg(provider_id) LIMIT 1; --- GetEntityByName retrieves an entity by its name for a project or hierarchy of projects. +-- GetEntityByName retrieves an entity by its name securely. -- name: GetEntityByName :one SELECT * FROM entity_instances -WHERE - entity_instances.name = sqlc.arg(name) - AND entity_instances.project_id = $1 - AND entity_instances.entity_type = $2 - AND entity_instances.provider_id = sqlc.arg(provider_id) +WHERE name = sqlc.arg(name) + AND entity_type = sqlc.arg(entity_type) + AND project_id = sqlc.arg(project_id) + AND provider_id = sqlc.arg(provider_id) LIMIT 1; --- GetEntitiesByType retrieves all entities of a given type for a project or hierarchy of projects. +-- GetEntitiesByType retrieves all entities of a given type for a project hierarchy. -- this is how one would get all repositories, artifacts, etc. - -- name: GetEntitiesByType :many SELECT * FROM entity_instances -WHERE entity_instances.entity_type = $1 - AND entity_instances.provider_id = sqlc.arg(provider_id) - AND entity_instances.project_id = ANY(sqlc.arg(projects)::uuid[]); +WHERE entity_type = sqlc.arg(entity_type) + AND provider_id = sqlc.arg(provider_id) + AND project_id = ANY(sqlc.arg(projects)::uuid[]); --- ListEntitiesAfterID retrieves entities of a given type after a cursor ID, for pagination. +-- ListEntitiesAfterID retrieves entities for pagination securely. -- This is used for cursor-based iteration over all entities (e.g., in the reminder service). - -- name: ListEntitiesAfterID :many SELECT * FROM entity_instances -WHERE entity_instances.entity_type = $1 - AND entity_instances.id > $2 -ORDER BY entity_instances.id +WHERE entity_type = sqlc.arg(entity_type) + AND id > sqlc.arg(id) + AND provider_id = sqlc.arg(provider_id) + AND project_id = ANY(sqlc.arg(projects)::uuid[]) +ORDER BY id LIMIT sqlc.arg('limit')::bigint; --- EntityExistsAfterID checks if any entity of a given type exists after a cursor ID. - +-- EntityExistsAfterID checks if any entity exists after a cursor ID securely. -- name: EntityExistsAfterID :one SELECT EXISTS ( SELECT 1 FROM entity_instances - WHERE entity_instances.entity_type = $1 - AND entity_instances.id > $2 + WHERE entity_type = sqlc.arg(entity_type) + AND id > sqlc.arg(id) + AND provider_id = sqlc.arg(provider_id) + AND project_id = ANY(sqlc.arg(projects)::uuid[]) ) AS exists; --- GetEntitiesByProvider retrieves all entities of a given provider. +-- GetEntitiesByProvider retrieves all entities of a given provider scoped by project hierarchy. -- this is how one would get all repositories, artifacts, etc. for a given provider. - -- name: GetEntitiesByProvider :many SELECT * FROM entity_instances -WHERE entity_instances.provider_id = $1; +WHERE provider_id = sqlc.arg(provider_id) + AND project_id = ANY(sqlc.arg(projects)::uuid[]); -- GetEntitiesByProjectHierarchy retrieves all entities for a project or hierarchy of projects. - -- name: GetEntitiesByProjectHierarchy :many SELECT * FROM entity_instances -WHERE entity_instances.project_id = ANY(sqlc.arg(projects)::uuid[]); - --- CountEntitiesByType counts all entities of a given type (across all projects/providers). +WHERE project_id = ANY(sqlc.arg(projects)::uuid[]); +-- CountEntitiesByType counts all entities of a given type (Global admin metric). -- name: CountEntitiesByType :one SELECT COUNT(*) FROM entity_instances -WHERE entity_instances.entity_type = $1; +WHERE entity_type = sqlc.arg(entity_type); -- CountEntitiesByTypeAndProject counts entities of a given type for a specific project. - -- name: CountEntitiesByTypeAndProject :one SELECT COUNT(*) FROM entity_instances -WHERE entity_instances.entity_type = $1 AND entity_instances.project_id = $2; +WHERE entity_type = sqlc.arg(entity_type) + AND project_id = sqlc.arg(project_id); +-- GetProperty retrieves a single property, using a JOIN to ensure the caller owns the parent entity -- name: GetProperty :one -SELECT * FROM properties -WHERE entity_id = $1 AND key = $2; +SELECT p.* FROM properties p +JOIN entity_instances ei ON p.entity_id = ei.id +WHERE p.entity_id = sqlc.arg(entity_id) + AND p.key = sqlc.arg(key) + AND ei.project_id = sqlc.arg(project_id) + AND ei.provider_id = sqlc.arg(provider_id); +-- DeleteProperty deletes a property, using USING to ensure the caller owns the parent entity. -- name: DeleteProperty :exec -DELETE FROM properties -WHERE entity_id = $1 AND key = $2; +DELETE FROM properties p +USING entity_instances ei +WHERE p.entity_id = ei.id + AND p.entity_id = sqlc.arg(entity_id) + AND p.key = sqlc.arg(key) + AND ei.project_id = sqlc.arg(project_id) + AND ei.provider_id = sqlc.arg(provider_id); +-- UpsertProperty upserts a property. +-- NOTE: Ownership MUST be verified in Go (e.g. via GetEntityByID) before executing this statement. -- name: UpsertProperty :one INSERT INTO properties ( entity_id, key, value, updated_at -) VALUES ($1, $2, $3, NOW()) +) VALUES ( + sqlc.arg(entity_id), + sqlc.arg(key), + sqlc.arg(value), + NOW() +) ON CONFLICT (entity_id, key) DO UPDATE - SET - value = sqlc.arg(value), - updated_at = NOW() +SET + value = sqlc.arg(value), + updated_at = NOW() RETURNING *; +-- GetAllPropertiesForEntity retrieves all properties for one entity, strictly bounded. -- name: GetAllPropertiesForEntity :many -SELECT * FROM properties -WHERE entity_id = $1; - +SELECT p.* FROM properties p +JOIN entity_instances ei ON p.entity_id = ei.id +WHERE p.entity_id = sqlc.arg(entity_id) + AND ei.project_id = sqlc.arg(project_id) + AND ei.provider_id = sqlc.arg(provider_id); + +-- GetPropertiesForEntities retrieves properties for multiple entities in bulk +-- name: GetPropertiesForEntities :many +SELECT p.* FROM properties p +JOIN entity_instances ei ON p.entity_id = ei.id +WHERE p.entity_id = ANY(sqlc.arg(entity_ids)::uuid[]) + AND ei.project_id = ANY(sqlc.arg(projects)::uuid[]) + AND ei.provider_id = sqlc.arg(provider_id); + +-- DeleteAllPropertiesForEntity deletes all properties for an entity securely. -- name: DeleteAllPropertiesForEntity :exec -DELETE FROM properties -WHERE entity_id = $1; - +DELETE FROM properties p +USING entity_instances ei +WHERE p.entity_id = ei.id + AND p.entity_id = sqlc.arg(entity_id) + AND ei.project_id = sqlc.arg(project_id) + AND ei.provider_id = sqlc.arg(provider_id); + +-- GetTypedEntitiesByProperty retrieves entities matching a specific property JSONB query securely. -- name: GetTypedEntitiesByProperty :many SELECT ei.* FROM entity_instances ei - JOIN properties p ON ei.id = p.entity_id +JOIN properties p ON ei.id = p.entity_id WHERE ei.entity_type = sqlc.arg(entity_type) AND (sqlc.arg(project_id)::uuid = '00000000-0000-0000-0000-000000000000'::uuid OR ei.project_id = sqlc.arg(project_id)) AND (sqlc.arg(provider_id)::uuid = '00000000-0000-0000-0000-000000000000'::uuid OR ei.provider_id = sqlc.arg(provider_id)) diff --git a/database/query/entity_execution_lock.sql b/database/query/entity_execution_lock.sql index 7f0e8398c3..132911194d 100644 --- a/database/query/entity_execution_lock.sql +++ b/database/query/entity_execution_lock.sql @@ -51,8 +51,21 @@ RETURNING *; -- name: FlushCache :one DELETE FROM flush_cache -WHERE entity_instance_id= $1 +WHERE flush_cache.entity_instance_id = sqlc.arg(entity_instance_id) + AND flush_cache.project_id = sqlc.arg(project_id) + AND EXISTS ( + SELECT 1 FROM entity_instances ei + WHERE ei.id = flush_cache.entity_instance_id + AND ei.provider_id = sqlc.arg(provider_id) + ) RETURNING *; -- name: ListFlushCache :many -SELECT * FROM flush_cache; \ No newline at end of file +SELECT + fc.entity, + fc.project_id, + fc.entity_instance_id, + fc.queued_at, + ei.provider_id +FROM flush_cache fc +JOIN entity_instances ei ON fc.entity_instance_id = ei.id; \ No newline at end of file diff --git a/database/query/eval_history.sql b/database/query/eval_history.sql index aca20f3cd4..6667e22b72 100644 --- a/database/query/eval_history.sql +++ b/database/query/eval_history.sql @@ -114,6 +114,7 @@ SELECT s.id::uuid AS evaluation_id, -- entity id ere.entity_instance_id as entity_id, j.id as project_id, + ei.provider_id, -- rule type, name, and profile rt.name AS rule_type, ri.name AS rule_name, diff --git a/database/query/profile_status.sql b/database/query/profile_status.sql index a57762d8d8..b4b0ef8656 100644 --- a/database/query/profile_status.sql +++ b/database/query/profile_status.sql @@ -90,6 +90,7 @@ SELECT ere.entity_instance_id as entity_id, ei.name as entity_name, ei.project_id as project_id, + ei.provider_id, rt.release_phase as rule_type_release_phase, eo.output AS eval_output FROM latest_evaluation_statuses les diff --git a/internal/controlplane/handlers_artifacts.go b/internal/controlplane/handlers_artifacts.go index 5bbc61827b..72a421aaab 100644 --- a/internal/controlplane/handlers_artifacts.go +++ b/internal/controlplane/handlers_artifacts.go @@ -94,7 +94,7 @@ func (s *Server) GetArtifactByName(ctx context.Context, in *pb.GetArtifactByName } // Fetch the entity with properties - ewp, err := s.props.EntityWithPropertiesByID(ctx, entities[0].ID, nil) + ewp, err := s.props.EntityWithPropertiesByID(ctx, entities[0].ID, projectID, provider.ID, nil) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, status.Errorf(codes.NotFound, "artifact not found") @@ -128,7 +128,7 @@ func (s *Server) GetArtifactByName(ctx context.Context, in *pb.GetArtifactByName return &pb.GetArtifactByNameResponse{ Artifact: pbArtifact, - Versions: nil, // explicitly nil, will probably deprecate that field later + Versions: nil, // explicitly nil, will probably deprecate that field later. }, nil } @@ -137,14 +137,23 @@ func (s *Server) GetArtifactByName(ctx context.Context, in *pb.GetArtifactByName func (s *Server) GetArtifactById(ctx context.Context, in *pb.GetArtifactByIdRequest) (*pb.GetArtifactByIdResponse, error) { entityCtx := engcontext.EntityFromContext(ctx) projectID := entityCtx.Project.ID + providerName := entityCtx.Provider.Name parsedArtifactID, err := uuid.Parse(in.Id) if err != nil { return nil, util.UserVisibleError(codes.InvalidArgument, "invalid artifact ID") } + provider, err := s.providerStore.GetByName(ctx, projectID, providerName) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, util.UserVisibleError(codes.NotFound, "provider not found") + } + return nil, status.Errorf(codes.Internal, "cannot get provider: %v", err) + } + // Fetch artifact entity - ewp, err := s.props.EntityWithPropertiesByID(ctx, parsedArtifactID, nil) + ewp, err := s.props.EntityWithPropertiesByID(ctx, parsedArtifactID, projectID, provider.ID, nil) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, status.Errorf(codes.NotFound, "artifact not found") @@ -152,11 +161,6 @@ func (s *Server) GetArtifactById(ctx context.Context, in *pb.GetArtifactByIdRequ return nil, status.Errorf(codes.Unknown, "failed to get artifact: %s", err) } - // Verify the entity belongs to the correct project - if ewp.Entity.ProjectID != projectID { - return nil, status.Errorf(codes.NotFound, "artifact not found") - } - // Verify it's an artifact entity if ewp.Entity.Type != pb.Entity_ENTITY_ARTIFACTS { return nil, status.Errorf(codes.InvalidArgument, "entity is not an artifact") diff --git a/internal/controlplane/handlers_entity_instances.go b/internal/controlplane/handlers_entity_instances.go index d2af091c9a..329dd818b4 100644 --- a/internal/controlplane/handlers_entity_instances.go +++ b/internal/controlplane/handlers_entity_instances.go @@ -95,15 +95,26 @@ func (s *Server) GetEntityById( return nil, util.UserVisibleError(codes.InvalidArgument, "invalid entity ID") } - projectID := GetProjectID(ctx) + entityCtx := engcontext.EntityFromContext(ctx) + projectID := entityCtx.Project.ID + providerName := entityCtx.Provider.Name + + provider, err := s.providerStore.GetByName(ctx, projectID, providerName) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, util.UserVisibleError(codes.NotFound, "provider not found") + } + return nil, fmt.Errorf("error getting provider: %w", err) + } // Call service to get entity - entity, err := s.entityService.GetEntityByID(ctx, entityID, projectID) + entity, err := s.entityService.GetEntityByID(ctx, entityID, projectID, provider.ID) if err != nil { return nil, err } // Telemetry logging + logger.BusinessRecord(ctx).Provider = providerName logger.BusinessRecord(ctx).Project = projectID logger.BusinessRecord(ctx).Entity = entityID @@ -168,15 +179,26 @@ func (s *Server) DeleteEntityById( return nil, util.UserVisibleError(codes.InvalidArgument, "invalid entity ID") } - projectID := GetProjectID(ctx) + entityCtx := engcontext.EntityFromContext(ctx) + projectID := entityCtx.Project.ID + providerName := entityCtx.Provider.Name + + provider, err := s.providerStore.GetByName(ctx, projectID, providerName) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, util.UserVisibleError(codes.NotFound, "provider not found") + } + return nil, fmt.Errorf("error getting provider: %w", err) + } // Call service to delete entity - err = s.entityService.DeleteEntityByID(ctx, entityID, projectID) + err = s.entityService.DeleteEntityByID(ctx, entityID, projectID, provider.ID) if err != nil { return nil, err } // Telemetry logging + logger.BusinessRecord(ctx).Provider = providerName // Added provider telemetry! logger.BusinessRecord(ctx).Project = projectID logger.BusinessRecord(ctx).Entity = entityID diff --git a/internal/controlplane/handlers_evalstatus.go b/internal/controlplane/handlers_evalstatus.go index 41a78b8fe9..f389e976d8 100644 --- a/internal/controlplane/handlers_evalstatus.go +++ b/internal/controlplane/handlers_evalstatus.go @@ -420,7 +420,7 @@ func (s *Server) sortEntitiesEvaluationStatus( continue } - efp, err := psc.EntityWithPropertiesByID(ctx, e.EntityID, nil) + efp, err := psc.EntityWithPropertiesByID(ctx, e.EntityID, p.Profile.ProjectID, e.ProviderID, nil) if err != nil { if errors.Is(err, propSvc.ErrEntityNotFound) { // If the entity is not found, log and skip diff --git a/internal/controlplane/handlers_evalstatus_test.go b/internal/controlplane/handlers_evalstatus_test.go index bd66141d2c..c3f5adc179 100644 --- a/internal/controlplane/handlers_evalstatus_test.go +++ b/internal/controlplane/handlers_evalstatus_test.go @@ -612,6 +612,7 @@ func TestListEvaluationResultsIncludeOutputs(t *testing.T) { profileID := uuid.New() entityID := uuid.New() ruleTypeID := uuid.New() + providerID := uuid.New() outputJSON := json.RawMessage(`{"finding":"something_wrong"}`) expectedOutput := &structpb.Value{} @@ -649,13 +650,15 @@ func TestListEvaluationResultsIncludeOutputs(t *testing.T) { efp := entmodels.NewEntityWithPropertiesFromInstance( entmodels.EntityInstance{ - ID: entityID, - Type: minderv1.Entity_ENTITY_REPOSITORIES, - Name: "mindersec/minder", + ID: entityID, + Type: minderv1.Entity_ENTITY_REPOSITORIES, + Name: "mindersec/minder", + ProjectID: projectID, + ProviderID: providerID, }, nil) mockProps.EXPECT(). - EntityWithPropertiesByID(gomock.Any(), entityID, gomock.Any()). + EntityWithPropertiesByID(gomock.Any(), entityID, gomock.Any(), gomock.Any(), gomock.Any()). Return(efp, nil) mockProps.EXPECT(). RetrieveAllPropertiesForEntity(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). @@ -682,6 +685,7 @@ func TestListEvaluationResultsIncludeOutputs(t *testing.T) { EntityID: entityID, EntityName: "mindersec/minder", ProjectID: projectID, + ProviderID: providerID, RuleTypeID: ruleTypeID, RuleName: "my_rule", RuleTypeName: "rule_type_a", diff --git a/internal/controlplane/handlers_profile.go b/internal/controlplane/handlers_profile.go index 93765fe322..793c537621 100644 --- a/internal/controlplane/handlers_profile.go +++ b/internal/controlplane/handlers_profile.go @@ -340,7 +340,13 @@ func (s *Server) getRuleEvalStatus( // the caller just ignores allt the errors anyway, so we don't start a transaction as the integrity issues // would not be discovered anyway - efp, err := s.props.EntityWithPropertiesByID(ctx, dbRuleEvalStat.EntityID, nil) + efp, err := s.props.EntityWithPropertiesByID( + ctx, + dbRuleEvalStat.EntityID, + dbRuleEvalStat.ProjectID, + dbRuleEvalStat.ProviderID, + nil, + ) if err != nil { return nil, fmt.Errorf("error fetching entity for properties: %w", err) } diff --git a/internal/controlplane/handlers_providers_test.go b/internal/controlplane/handlers_providers_test.go index a0a162aa9d..df382b3088 100644 --- a/internal/controlplane/handlers_providers_test.go +++ b/internal/controlplane/handlers_providers_test.go @@ -559,7 +559,7 @@ func TestDeleteProvider(t *testing.T) { mockprops := propSvc.NewMockPropertiesService(ctrl) mockprops.EXPECT(). - EntityWithPropertiesByID(gomock.Any(), gomock.Any(), nil). + EntityWithPropertiesByID(gomock.Any(), gomock.Any(), projectID, providerID, nil). Return(models.NewEntityWithPropertiesFromInstance( models.EntityInstance{}, nil), nil) @@ -682,7 +682,7 @@ func TestDeleteProviderByID(t *testing.T) { mockprops := propSvc.NewMockPropertiesService(ctrl) mockprops.EXPECT(). - EntityWithPropertiesByID(gomock.Any(), gomock.Any(), nil). + EntityWithPropertiesByID(gomock.Any(), gomock.Any(), projectID, providerID, nil). Return(models.NewEntityWithPropertiesFromInstance( models.EntityInstance{}, nil), nil) diff --git a/internal/controlplane/handlers_reconciliationtasks.go b/internal/controlplane/handlers_reconciliationtasks.go index 9a6e008baf..73bf7cace0 100644 --- a/internal/controlplane/handlers_reconciliationtasks.go +++ b/internal/controlplane/handlers_reconciliationtasks.go @@ -69,7 +69,7 @@ func (s *Server) CreateEntityReconciliationTask(ctx context.Context, var msg *message.Message var topic string - msg, err = getRepositoryReconciliationMessage(ctx, s.store, entity.GetId(), entityCtx) + msg, err = getRepositoryReconciliationMessage(ctx, s.store, entity.GetId(), entityCtx, dbProvider.ID) if err != nil { return nil, err } @@ -83,26 +83,31 @@ func (s *Server) CreateEntityReconciliationTask(ctx context.Context, return &pb.CreateEntityReconciliationTaskResponse{}, nil } -func getRepositoryReconciliationMessage(ctx context.Context, store db.Store, - repoIdString string, entityCtx engcontext.EntityContext) (*message.Message, error) { +func getRepositoryReconciliationMessage( + ctx context.Context, + store db.Store, + repoIdString string, + entityCtx engcontext.EntityContext, + providerID uuid.UUID, +) (*message.Message, error) { repoUUID, err := uuid.Parse(repoIdString) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "error parsing repository id: %v", err) } // Fetch entity by ID - ent, err := store.GetEntityByID(ctx, repoUUID) + ent, err := store.GetEntityByID(ctx, db.GetEntityByIDParams{ + ID: repoUUID, + ProjectID: entityCtx.Project.ID, + ProviderID: providerID, + }) + if errors.Is(err, sql.ErrNoRows) { return nil, status.Errorf(codes.NotFound, "repository not found") } else if err != nil { return nil, status.Errorf(codes.Internal, "cannot read repository: %v", err) } - // Verify project matches - if ent.ProjectID != entityCtx.Project.ID { - return nil, status.Errorf(codes.NotFound, "repository not found") - } - // Telemetry logging logger.BusinessRecord(ctx).ProviderID = ent.ProviderID logger.BusinessRecord(ctx).Project = ent.ProjectID diff --git a/internal/controlplane/handlers_reconciliationtasks_test.go b/internal/controlplane/handlers_reconciliationtasks_test.go index 79f9295f6b..2d47df1f13 100644 --- a/internal/controlplane/handlers_reconciliationtasks_test.go +++ b/internal/controlplane/handlers_reconciliationtasks_test.go @@ -52,13 +52,18 @@ func TestServer_CreateRepositoryReconciliationTask(t *testing.T) { setup: func(store *mockdb.MockStore, entityContext *engcontext.EntityContext) { projId := entityContext.Project.ID prov := entityContext.Provider.Name - setupTestingEntityContextValidation(store, projId, prov, uuid.New()) + provId := uuid.New() + setupTestingEntityContextValidation(store, projId, prov, provId) store.EXPECT(). - GetEntityByID(gomock.Any(), repoUuid). + GetEntityByID(gomock.Any(), db.GetEntityByIDParams{ + ID: repoUuid, + ProjectID: projId, + ProviderID: provId, + }). Return(db.EntityInstance{ ID: repoUuid, EntityType: db.EntitiesRepository, - ProviderID: uuid.New(), + ProviderID: provId, ProjectID: projId, }, nil) }, @@ -88,17 +93,21 @@ func TestServer_CreateRepositoryReconciliationTask(t *testing.T) { store.EXPECT(). GetEntityByName(gomock.Any(), db.GetEntityByNameParams{ ProjectID: projId, - EntityType: "repository", + EntityType: db.EntitiesRepository, Name: "my/repo", ProviderID: provId, }). - Return(db.EntityInstance{ID: repoUuid}, nil) + Return(db.EntityInstance{ID: repoUuid, ProjectID: projId, ProviderID: provId}, nil) store.EXPECT(). - GetEntityByID(gomock.Any(), repoUuid). + GetEntityByID(gomock.Any(), db.GetEntityByIDParams{ + ID: repoUuid, + ProjectID: projId, + ProviderID: provId, + }). Return(db.EntityInstance{ ID: repoUuid, EntityType: db.EntitiesRepository, - ProviderID: uuid.New(), + ProviderID: provId, ProjectID: projId, }, nil) }, @@ -120,6 +129,7 @@ func TestServer_CreateRepositoryReconciliationTask(t *testing.T) { }, setup: func(store *mockdb.MockStore, entityContext *engcontext.EntityContext) { projId := entityContext.Project.ID + provId := uuid.New() store.EXPECT(). GetProjectByID(gomock.Any(), projId). Return(db.Project{ID: projId}, nil) @@ -131,13 +141,17 @@ func TestServer_CreateRepositoryReconciliationTask(t *testing.T) { Name: sql.NullString{String: "", Valid: false}, Projects: []uuid.UUID{projId}, Trait: db.NullProviderType{}, - }).Return([]db.Provider{{Name: ghProvider}}, nil) + }).Return([]db.Provider{{Name: ghProvider, ID: provId}}, nil) store.EXPECT(). - GetEntityByID(gomock.Any(), repoUuid). + GetEntityByID(gomock.Any(), db.GetEntityByIDParams{ + ID: repoUuid, + ProjectID: projId, + ProviderID: provId, + }). Return(db.EntityInstance{ ID: repoUuid, EntityType: db.EntitiesRepository, - ProviderID: uuid.New(), + ProviderID: provId, ProjectID: projId, }, nil) }, @@ -186,9 +200,14 @@ func TestServer_CreateRepositoryReconciliationTask(t *testing.T) { setup: func(store *mockdb.MockStore, entityContext *engcontext.EntityContext) { projId := entityContext.Project.ID prov := entityContext.Provider.Name - setupTestingEntityContextValidation(store, projId, prov, uuid.New()) + provId := uuid.New() + setupTestingEntityContextValidation(store, projId, prov, provId) store.EXPECT(). - GetEntityByID(gomock.Any(), repoUuid). + GetEntityByID(gomock.Any(), db.GetEntityByIDParams{ + ID: repoUuid, + ProjectID: projId, + ProviderID: provId, + }). Return(db.EntityInstance{}, sql.ErrNoRows) }, err: "repository not found", @@ -212,9 +231,14 @@ func TestServer_CreateRepositoryReconciliationTask(t *testing.T) { setup: func(store *mockdb.MockStore, entityContext *engcontext.EntityContext) { projId := entityContext.Project.ID prov := entityContext.Provider.Name - setupTestingEntityContextValidation(store, projId, prov, uuid.New()) + provId := uuid.New() + setupTestingEntityContextValidation(store, projId, prov, provId) store.EXPECT(). - GetEntityByID(gomock.Any(), repoUuid). + GetEntityByID(gomock.Any(), db.GetEntityByIDParams{ + ID: repoUuid, + ProjectID: projId, + ProviderID: provId, + }). Return(db.EntityInstance{}, sql.ErrConnDone) }, err: sql.ErrConnDone.Error(), diff --git a/internal/db/entities.sql.go b/internal/db/entities.sql.go index fe75ce30cf..519e81926a 100644 --- a/internal/db/entities.sql.go +++ b/internal/db/entities.sql.go @@ -14,12 +14,11 @@ import ( ) const countEntitiesByType = `-- name: CountEntitiesByType :one - SELECT COUNT(*) FROM entity_instances -WHERE entity_instances.entity_type = $1 +WHERE entity_type = $1 ` -// CountEntitiesByType counts all entities of a given type (across all projects/providers). +// CountEntitiesByType counts all entities of a given type (Global admin metric). func (q *Queries) CountEntitiesByType(ctx context.Context, entityType Entities) (int64, error) { row := q.db.QueryRowContext(ctx, countEntitiesByType, entityType) var count int64 @@ -28,9 +27,9 @@ func (q *Queries) CountEntitiesByType(ctx context.Context, entityType Entities) } const countEntitiesByTypeAndProject = `-- name: CountEntitiesByTypeAndProject :one - SELECT COUNT(*) FROM entity_instances -WHERE entity_instances.entity_type = $1 AND entity_instances.project_id = $2 +WHERE entity_type = $1 + AND project_id = $2 ` type CountEntitiesByTypeAndProjectParams struct { @@ -47,14 +46,19 @@ func (q *Queries) CountEntitiesByTypeAndProject(ctx context.Context, arg CountEn } const createEntity = `-- name: CreateEntity :one - INSERT INTO entity_instances ( entity_type, name, project_id, provider_id, originated_from -) VALUES ($1, $2, $3, $4, $5) +) VALUES ( + $1, + $2, + $3, + $4, + $5 +) RETURNING id, entity_type, name, project_id, provider_id, created_at, originated_from ` @@ -89,7 +93,6 @@ func (q *Queries) CreateEntity(ctx context.Context, arg CreateEntityParams) (Ent } const createEntityWithID = `-- name: CreateEntityWithID :one - INSERT INTO entity_instances ( id, entity_type, @@ -97,7 +100,14 @@ INSERT INTO entity_instances ( project_id, provider_id, originated_from -) VALUES ($1, $2, $3, $4, $5, $6) +) VALUES ( + $1, + $2, + $3, + $4, + $5, + $6 +) RETURNING id, entity_type, name, project_id, provider_id, created_at, originated_from ` @@ -134,7 +144,6 @@ func (q *Queries) CreateEntityWithID(ctx context.Context, arg CreateEntityWithID } const createOrEnsureEntityByID = `-- name: CreateOrEnsureEntityByID :one - INSERT INTO entity_instances ( id, entity_type, @@ -142,7 +151,14 @@ INSERT INTO entity_instances ( project_id, provider_id, originated_from -) VALUES ($1, $2, $3, $4, $5, $6) +) VALUES ( + $1, + $2, + $3, + $4, + $5, + $6 +) ON CONFLICT (id) DO UPDATE SET id = entity_instances.id -- This is a "noop" update to ensure the RETURNING clause works @@ -182,76 +198,121 @@ func (q *Queries) CreateOrEnsureEntityByID(ctx context.Context, arg CreateOrEnsu } const deleteAllPropertiesForEntity = `-- name: DeleteAllPropertiesForEntity :exec -DELETE FROM properties -WHERE entity_id = $1 +DELETE FROM properties p +USING entity_instances ei +WHERE p.entity_id = ei.id + AND p.entity_id = $1 + AND ei.project_id = $2 + AND ei.provider_id = $3 ` -func (q *Queries) DeleteAllPropertiesForEntity(ctx context.Context, entityID uuid.UUID) error { - _, err := q.db.ExecContext(ctx, deleteAllPropertiesForEntity, entityID) +type DeleteAllPropertiesForEntityParams struct { + EntityID uuid.UUID `json:"entity_id"` + ProjectID uuid.UUID `json:"project_id"` + ProviderID uuid.UUID `json:"provider_id"` +} + +// DeleteAllPropertiesForEntity deletes all properties for an entity securely. +func (q *Queries) DeleteAllPropertiesForEntity(ctx context.Context, arg DeleteAllPropertiesForEntityParams) error { + _, err := q.db.ExecContext(ctx, deleteAllPropertiesForEntity, arg.EntityID, arg.ProjectID, arg.ProviderID) return err } const deleteEntity = `-- name: DeleteEntity :exec DELETE FROM entity_instances -WHERE id = $1 AND project_id = $2 +WHERE id = $1 + AND project_id = $2 + AND provider_id = $3 ` type DeleteEntityParams struct { - ID uuid.UUID `json:"id"` - ProjectID uuid.UUID `json:"project_id"` + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` + ProviderID uuid.UUID `json:"provider_id"` } -// DeleteEntity removes an entity from the entity_instances table for a project. +// DeleteEntity removes an entity from the entity_instances table securely. func (q *Queries) DeleteEntity(ctx context.Context, arg DeleteEntityParams) error { - _, err := q.db.ExecContext(ctx, deleteEntity, arg.ID, arg.ProjectID) + _, err := q.db.ExecContext(ctx, deleteEntity, arg.ID, arg.ProjectID, arg.ProviderID) return err } const deleteProperty = `-- name: DeleteProperty :exec -DELETE FROM properties -WHERE entity_id = $1 AND key = $2 +DELETE FROM properties p +USING entity_instances ei +WHERE p.entity_id = ei.id + AND p.entity_id = $1 + AND p.key = $2 + AND ei.project_id = $3 + AND ei.provider_id = $4 ` type DeletePropertyParams struct { - EntityID uuid.UUID `json:"entity_id"` - Key string `json:"key"` + EntityID uuid.UUID `json:"entity_id"` + Key string `json:"key"` + ProjectID uuid.UUID `json:"project_id"` + ProviderID uuid.UUID `json:"provider_id"` } +// DeleteProperty deletes a property, using USING to ensure the caller owns the parent entity. func (q *Queries) DeleteProperty(ctx context.Context, arg DeletePropertyParams) error { - _, err := q.db.ExecContext(ctx, deleteProperty, arg.EntityID, arg.Key) + _, err := q.db.ExecContext(ctx, deleteProperty, + arg.EntityID, + arg.Key, + arg.ProjectID, + arg.ProviderID, + ) return err } const entityExistsAfterID = `-- name: EntityExistsAfterID :one - SELECT EXISTS ( SELECT 1 FROM entity_instances - WHERE entity_instances.entity_type = $1 - AND entity_instances.id > $2 + WHERE entity_type = $1 + AND id > $2 + AND provider_id = $3 + AND project_id = ANY($4::uuid[]) ) AS exists ` type EntityExistsAfterIDParams struct { - EntityType Entities `json:"entity_type"` - ID uuid.UUID `json:"id"` + EntityType Entities `json:"entity_type"` + ID uuid.UUID `json:"id"` + ProviderID uuid.UUID `json:"provider_id"` + Projects []uuid.UUID `json:"projects"` } -// EntityExistsAfterID checks if any entity of a given type exists after a cursor ID. +// EntityExistsAfterID checks if any entity exists after a cursor ID securely. func (q *Queries) EntityExistsAfterID(ctx context.Context, arg EntityExistsAfterIDParams) (bool, error) { - row := q.db.QueryRowContext(ctx, entityExistsAfterID, arg.EntityType, arg.ID) + row := q.db.QueryRowContext(ctx, entityExistsAfterID, + arg.EntityType, + arg.ID, + arg.ProviderID, + pq.Array(arg.Projects), + ) var exists bool err := row.Scan(&exists) return exists, err } const getAllPropertiesForEntity = `-- name: GetAllPropertiesForEntity :many -SELECT id, entity_id, key, value, updated_at FROM properties -WHERE entity_id = $1 +SELECT p.id, p.entity_id, p.key, p.value, p.updated_at FROM properties p +JOIN entity_instances ei ON p.entity_id = ei.id +WHERE p.entity_id = $1 + AND ei.project_id = $2 + AND ei.provider_id = $3 ` -func (q *Queries) GetAllPropertiesForEntity(ctx context.Context, entityID uuid.UUID) ([]Property, error) { - rows, err := q.db.QueryContext(ctx, getAllPropertiesForEntity, entityID) +type GetAllPropertiesForEntityParams struct { + EntityID uuid.UUID `json:"entity_id"` + ProjectID uuid.UUID `json:"project_id"` + ProviderID uuid.UUID `json:"provider_id"` +} + +// GetAllPropertiesForEntity retrieves all properties for one entity, strictly bounded. +func (q *Queries) GetAllPropertiesForEntity(ctx context.Context, arg GetAllPropertiesForEntityParams) ([]Property, error) { + rows, err := q.db.QueryContext(ctx, getAllPropertiesForEntity, arg.EntityID, arg.ProjectID, arg.ProviderID) if err != nil { return nil, err } @@ -280,9 +341,8 @@ func (q *Queries) GetAllPropertiesForEntity(ctx context.Context, entityID uuid.U } const getEntitiesByProjectHierarchy = `-- name: GetEntitiesByProjectHierarchy :many - SELECT id, entity_type, name, project_id, provider_id, created_at, originated_from FROM entity_instances -WHERE entity_instances.project_id = ANY($1::uuid[]) +WHERE project_id = ANY($1::uuid[]) ` // GetEntitiesByProjectHierarchy retrieves all entities for a project or hierarchy of projects. @@ -318,15 +378,20 @@ func (q *Queries) GetEntitiesByProjectHierarchy(ctx context.Context, projects [] } const getEntitiesByProvider = `-- name: GetEntitiesByProvider :many - SELECT id, entity_type, name, project_id, provider_id, created_at, originated_from FROM entity_instances -WHERE entity_instances.provider_id = $1 +WHERE provider_id = $1 + AND project_id = ANY($2::uuid[]) ` -// GetEntitiesByProvider retrieves all entities of a given provider. +type GetEntitiesByProviderParams struct { + ProviderID uuid.UUID `json:"provider_id"` + Projects []uuid.UUID `json:"projects"` +} + +// GetEntitiesByProvider retrieves all entities of a given provider scoped by project hierarchy. // this is how one would get all repositories, artifacts, etc. for a given provider. -func (q *Queries) GetEntitiesByProvider(ctx context.Context, providerID uuid.UUID) ([]EntityInstance, error) { - rows, err := q.db.QueryContext(ctx, getEntitiesByProvider, providerID) +func (q *Queries) GetEntitiesByProvider(ctx context.Context, arg GetEntitiesByProviderParams) ([]EntityInstance, error) { + rows, err := q.db.QueryContext(ctx, getEntitiesByProvider, arg.ProviderID, pq.Array(arg.Projects)) if err != nil { return nil, err } @@ -357,11 +422,10 @@ func (q *Queries) GetEntitiesByProvider(ctx context.Context, providerID uuid.UUI } const getEntitiesByType = `-- name: GetEntitiesByType :many - SELECT id, entity_type, name, project_id, provider_id, created_at, originated_from FROM entity_instances -WHERE entity_instances.entity_type = $1 - AND entity_instances.provider_id = $2 - AND entity_instances.project_id = ANY($3::uuid[]) +WHERE entity_type = $1 + AND provider_id = $2 + AND project_id = ANY($3::uuid[]) ` type GetEntitiesByTypeParams struct { @@ -370,7 +434,7 @@ type GetEntitiesByTypeParams struct { Projects []uuid.UUID `json:"projects"` } -// GetEntitiesByType retrieves all entities of a given type for a project or hierarchy of projects. +// GetEntitiesByType retrieves all entities of a given type for a project hierarchy. // this is how one would get all repositories, artifacts, etc. func (q *Queries) GetEntitiesByType(ctx context.Context, arg GetEntitiesByTypeParams) ([]EntityInstance, error) { rows, err := q.db.QueryContext(ctx, getEntitiesByType, arg.EntityType, arg.ProviderID, pq.Array(arg.Projects)) @@ -405,13 +469,21 @@ func (q *Queries) GetEntitiesByType(ctx context.Context, arg GetEntitiesByTypePa const getEntityByID = `-- name: GetEntityByID :one SELECT id, entity_type, name, project_id, provider_id, created_at, originated_from FROM entity_instances -WHERE entity_instances.id = $1 +WHERE id = $1 + AND project_id = $2 + AND provider_id = $3 LIMIT 1 ` +type GetEntityByIDParams struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` + ProviderID uuid.UUID `json:"provider_id"` +} + // GetEntityByID retrieves an entity by its ID for a project or hierarchy of projects. -func (q *Queries) GetEntityByID(ctx context.Context, id uuid.UUID) (EntityInstance, error) { - row := q.db.QueryRowContext(ctx, getEntityByID, id) +func (q *Queries) GetEntityByID(ctx context.Context, arg GetEntityByIDParams) (EntityInstance, error) { + row := q.db.QueryRowContext(ctx, getEntityByID, arg.ID, arg.ProjectID, arg.ProviderID) var i EntityInstance err := row.Scan( &i.ID, @@ -427,27 +499,26 @@ func (q *Queries) GetEntityByID(ctx context.Context, id uuid.UUID) (EntityInstan const getEntityByName = `-- name: GetEntityByName :one SELECT id, entity_type, name, project_id, provider_id, created_at, originated_from FROM entity_instances -WHERE - entity_instances.name = $3 - AND entity_instances.project_id = $1 - AND entity_instances.entity_type = $2 - AND entity_instances.provider_id = $4 +WHERE name = $1 + AND entity_type = $2 + AND project_id = $3 + AND provider_id = $4 LIMIT 1 ` type GetEntityByNameParams struct { - ProjectID uuid.UUID `json:"project_id"` - EntityType Entities `json:"entity_type"` Name string `json:"name"` + EntityType Entities `json:"entity_type"` + ProjectID uuid.UUID `json:"project_id"` ProviderID uuid.UUID `json:"provider_id"` } -// GetEntityByName retrieves an entity by its name for a project or hierarchy of projects. +// GetEntityByName retrieves an entity by its name securely. func (q *Queries) GetEntityByName(ctx context.Context, arg GetEntityByNameParams) (EntityInstance, error) { row := q.db.QueryRowContext(ctx, getEntityByName, - arg.ProjectID, - arg.EntityType, arg.Name, + arg.EntityType, + arg.ProjectID, arg.ProviderID, ) var i EntityInstance @@ -463,18 +534,74 @@ func (q *Queries) GetEntityByName(ctx context.Context, arg GetEntityByNameParams return i, err } +const getPropertiesForEntities = `-- name: GetPropertiesForEntities :many +SELECT p.id, p.entity_id, p.key, p.value, p.updated_at FROM properties p +JOIN entity_instances ei ON p.entity_id = ei.id +WHERE p.entity_id = ANY($1::uuid[]) + AND ei.project_id = ANY($2::uuid[]) + AND ei.provider_id = $3 +` + +type GetPropertiesForEntitiesParams struct { + EntityIds []uuid.UUID `json:"entity_ids"` + Projects []uuid.UUID `json:"projects"` + ProviderID uuid.UUID `json:"provider_id"` +} + +// GetPropertiesForEntities retrieves properties for multiple entities in bulk +func (q *Queries) GetPropertiesForEntities(ctx context.Context, arg GetPropertiesForEntitiesParams) ([]Property, error) { + rows, err := q.db.QueryContext(ctx, getPropertiesForEntities, pq.Array(arg.EntityIds), pq.Array(arg.Projects), arg.ProviderID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Property{} + for rows.Next() { + var i Property + if err := rows.Scan( + &i.ID, + &i.EntityID, + &i.Key, + &i.Value, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getProperty = `-- name: GetProperty :one -SELECT id, entity_id, key, value, updated_at FROM properties -WHERE entity_id = $1 AND key = $2 +SELECT p.id, p.entity_id, p.key, p.value, p.updated_at FROM properties p +JOIN entity_instances ei ON p.entity_id = ei.id +WHERE p.entity_id = $1 + AND p.key = $2 + AND ei.project_id = $3 + AND ei.provider_id = $4 ` type GetPropertyParams struct { - EntityID uuid.UUID `json:"entity_id"` - Key string `json:"key"` + EntityID uuid.UUID `json:"entity_id"` + Key string `json:"key"` + ProjectID uuid.UUID `json:"project_id"` + ProviderID uuid.UUID `json:"provider_id"` } +// GetProperty retrieves a single property, using a JOIN to ensure the caller owns the parent entity func (q *Queries) GetProperty(ctx context.Context, arg GetPropertyParams) (Property, error) { - row := q.db.QueryRowContext(ctx, getProperty, arg.EntityID, arg.Key) + row := q.db.QueryRowContext(ctx, getProperty, + arg.EntityID, + arg.Key, + arg.ProjectID, + arg.ProviderID, + ) var i Property err := row.Scan( &i.ID, @@ -489,7 +616,7 @@ func (q *Queries) GetProperty(ctx context.Context, arg GetPropertyParams) (Prope const getTypedEntitiesByProperty = `-- name: GetTypedEntitiesByProperty :many SELECT ei.id, ei.entity_type, ei.name, ei.project_id, ei.provider_id, ei.created_at, ei.originated_from FROM entity_instances ei - JOIN properties p ON ei.id = p.entity_id +JOIN properties p ON ei.id = p.entity_id WHERE ei.entity_type = $1 AND ($2::uuid = '00000000-0000-0000-0000-000000000000'::uuid OR ei.project_id = $2) AND ($3::uuid = '00000000-0000-0000-0000-000000000000'::uuid OR ei.provider_id = $3) @@ -505,6 +632,7 @@ type GetTypedEntitiesByPropertyParams struct { Value json.RawMessage `json:"value"` } +// GetTypedEntitiesByProperty retrieves entities matching a specific property JSONB query securely. func (q *Queries) GetTypedEntitiesByProperty(ctx context.Context, arg GetTypedEntitiesByPropertyParams) ([]EntityInstance, error) { rows, err := q.db.QueryContext(ctx, getTypedEntitiesByProperty, arg.EntityType, @@ -543,24 +671,33 @@ func (q *Queries) GetTypedEntitiesByProperty(ctx context.Context, arg GetTypedEn } const listEntitiesAfterID = `-- name: ListEntitiesAfterID :many - SELECT id, entity_type, name, project_id, provider_id, created_at, originated_from FROM entity_instances -WHERE entity_instances.entity_type = $1 - AND entity_instances.id > $2 -ORDER BY entity_instances.id -LIMIT $3::bigint +WHERE entity_type = $1 + AND id > $2 + AND provider_id = $3 + AND project_id = ANY($4::uuid[]) +ORDER BY id +LIMIT $5::bigint ` type ListEntitiesAfterIDParams struct { - EntityType Entities `json:"entity_type"` - ID uuid.UUID `json:"id"` - Limit int64 `json:"limit"` + EntityType Entities `json:"entity_type"` + ID uuid.UUID `json:"id"` + ProviderID uuid.UUID `json:"provider_id"` + Projects []uuid.UUID `json:"projects"` + Limit int64 `json:"limit"` } -// ListEntitiesAfterID retrieves entities of a given type after a cursor ID, for pagination. +// ListEntitiesAfterID retrieves entities for pagination securely. // This is used for cursor-based iteration over all entities (e.g., in the reminder service). func (q *Queries) ListEntitiesAfterID(ctx context.Context, arg ListEntitiesAfterIDParams) ([]EntityInstance, error) { - rows, err := q.db.QueryContext(ctx, listEntitiesAfterID, arg.EntityType, arg.ID, arg.Limit) + rows, err := q.db.QueryContext(ctx, listEntitiesAfterID, + arg.EntityType, + arg.ID, + arg.ProviderID, + pq.Array(arg.Projects), + arg.Limit, + ) if err != nil { return nil, err } @@ -596,11 +733,16 @@ INSERT INTO properties ( key, value, updated_at -) VALUES ($1, $2, $3, NOW()) +) VALUES ( + $1, + $2, + $3, + NOW() +) ON CONFLICT (entity_id, key) DO UPDATE - SET - value = $4, - updated_at = NOW() +SET + value = $3, + updated_at = NOW() RETURNING id, entity_id, key, value, updated_at ` @@ -610,13 +752,10 @@ type UpsertPropertyParams struct { Value json.RawMessage `json:"value"` } +// UpsertProperty upserts a property. +// NOTE: Ownership MUST be verified in Go (e.g. via GetEntityByID) before executing this statement. func (q *Queries) UpsertProperty(ctx context.Context, arg UpsertPropertyParams) (Property, error) { - row := q.db.QueryRowContext(ctx, upsertProperty, - arg.EntityID, - arg.Key, - arg.Value, - arg.Value, - ) + row := q.db.QueryRowContext(ctx, upsertProperty, arg.EntityID, arg.Key, arg.Value) var i Property err := row.Scan( &i.ID, diff --git a/internal/db/entities_test.go b/internal/db/entities_test.go index 488245a305..281404eb5a 100644 --- a/internal/db/entities_test.go +++ b/internal/db/entities_test.go @@ -21,13 +21,14 @@ func Test_EntityCrud(t *testing.T) { org := createRandomOrganization(t) proj := createRandomProject(t, org.ID) prov := createRandomProvider(t, proj.ID) + ctx := context.Background() t.Run("CreateEntity", func(t *testing.T) { t.Parallel() const testRepoName = "testorg/testrepo" - ent, err := testQueries.CreateEntity(context.Background(), CreateEntityParams{ + ent, err := testQueries.CreateEntity(ctx, CreateEntityParams{ EntityType: EntitiesRepository, Name: testRepoName, ProjectID: proj.ID, @@ -43,17 +44,26 @@ func Test_EntityCrud(t *testing.T) { require.Equal(t, ent.ProviderID, prov.ID) require.Equal(t, ent.OriginatedFrom, uuid.NullUUID{}) - entGet, err := testQueries.GetEntityByID(context.Background(), ent.ID) + entGet, err := testQueries.GetEntityByID(ctx, GetEntityByIDParams{ + ID: ent.ID, + ProjectID: proj.ID, + ProviderID: prov.ID, + }) require.NoError(t, err) require.Equal(t, entGet, ent) - err = testQueries.DeleteEntity(context.Background(), DeleteEntityParams{ - ID: ent.ID, - ProjectID: proj.ID, + err = testQueries.DeleteEntity(ctx, DeleteEntityParams{ + ID: ent.ID, + ProjectID: proj.ID, + ProviderID: prov.ID, }) require.NoError(t, err) - entGet, err = testQueries.GetEntityByID(context.Background(), ent.ID) + entGet, err = testQueries.GetEntityByID(ctx, GetEntityByIDParams{ + ID: ent.ID, + ProjectID: proj.ID, + ProviderID: prov.ID, + }) require.ErrorIs(t, err, sql.ErrNoRows) require.Empty(t, entGet) }) @@ -61,7 +71,7 @@ func Test_EntityCrud(t *testing.T) { t.Run("No such entity", func(t *testing.T) { t.Parallel() - ent, err := testQueries.GetEntityByName(context.Background(), GetEntityByNameParams{ + ent, err := testQueries.GetEntityByName(ctx, GetEntityByNameParams{ ProjectID: proj.ID, Name: "garbage/nosuchentity", EntityType: EntitiesRepository, @@ -76,7 +86,7 @@ func Test_EntityCrud(t *testing.T) { const testEntName = "testorg/testent" - entRepo, err := testQueries.CreateEntity(context.Background(), CreateEntityParams{ + entRepo, err := testQueries.CreateEntity(ctx, CreateEntityParams{ EntityType: EntitiesRepository, Name: testEntName, ProjectID: proj.ID, @@ -85,11 +95,8 @@ func Test_EntityCrud(t *testing.T) { }) require.NoError(t, err) require.NotEmpty(t, entRepo) - require.NotEqual(t, entRepo.ID, uuid.Nil) - require.Equal(t, entRepo.EntityType, EntitiesRepository) - require.Equal(t, entRepo.Name, testEntName) - entArtifact, err := testQueries.CreateEntity(context.Background(), CreateEntityParams{ + entArtifact, err := testQueries.CreateEntity(ctx, CreateEntityParams{ EntityType: EntitiesArtifact, Name: testEntName, ProjectID: proj.ID, @@ -101,33 +108,25 @@ func Test_EntityCrud(t *testing.T) { }) require.NoError(t, err) require.NotEmpty(t, entArtifact) - require.NotEqual(t, entArtifact.ID, uuid.Nil) - require.Equal(t, entArtifact.EntityType, EntitiesArtifact) - require.Equal(t, entArtifact.Name, testEntName) - require.Equal(t, entArtifact.OriginatedFrom, uuid.NullUUID{ - UUID: entRepo.ID, - Valid: true, - }) - getRepo, err := testQueries.GetEntityByName(context.Background(), GetEntityByNameParams{ + getRepo, err := testQueries.GetEntityByName(ctx, GetEntityByNameParams{ ProjectID: proj.ID, Name: testEntName, EntityType: EntitiesRepository, ProviderID: prov.ID, }) require.NoError(t, err) - require.NotEmpty(t, getRepo) require.Equal(t, getRepo, entRepo) - getArtifact, err := testQueries.GetEntityByName(context.Background(), GetEntityByNameParams{ + getArtifact, err := testQueries.GetEntityByName(ctx, GetEntityByNameParams{ ProjectID: proj.ID, Name: testEntName, - EntityType: EntitiesRepository, + EntityType: EntitiesArtifact, ProviderID: prov.ID, }) require.NoError(t, err) - require.NotEmpty(t, getRepo) - require.Equal(t, getArtifact, entRepo) + require.NotEmpty(t, getArtifact) + require.Equal(t, getArtifact, entArtifact) }) } @@ -137,13 +136,14 @@ func Test_PropertyCrud(t *testing.T) { org := createRandomOrganization(t) proj := createRandomProject(t, org.ID) prov := createRandomProvider(t, proj.ID) + ctx := context.Background() t.Run("UpsertProperty", func(t *testing.T) { t.Parallel() const testRepoName = "testorg/testrepo_props" - ent, err := testQueries.CreateEntity(context.Background(), CreateEntityParams{ + ent, err := testQueries.CreateEntity(ctx, CreateEntityParams{ EntityType: EntitiesRepository, Name: testRepoName, ProjectID: proj.ID, @@ -153,11 +153,15 @@ func Test_PropertyCrud(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, ent) - dbProp, err := testQueries.GetAllPropertiesForEntity(context.Background(), ent.ID) + dbProp, err := testQueries.GetAllPropertiesForEntity(ctx, GetAllPropertiesForEntityParams{ + EntityID: ent.ID, + ProjectID: proj.ID, + ProviderID: prov.ID, + }) require.NoError(t, err) require.Empty(t, dbProp) - prop, err := testQueries.UpsertPropertyValueV1(context.Background(), UpsertPropertyValueV1Params{ + prop, err := testQueries.UpsertPropertyValueV1(ctx, UpsertPropertyValueV1Params{ EntityID: ent.ID, Key: "testkey", Value: "testvalue", @@ -165,7 +169,7 @@ func Test_PropertyCrud(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, prop) - prop, err = testQueries.UpsertPropertyValueV1(context.Background(), UpsertPropertyValueV1Params{ + prop, err = testQueries.UpsertPropertyValueV1(ctx, UpsertPropertyValueV1Params{ EntityID: ent.ID, Key: "anotherkey", Value: "anothervalue", @@ -173,24 +177,32 @@ func Test_PropertyCrud(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, prop) - dbProp, err = testQueries.GetAllPropertiesForEntity(context.Background(), ent.ID) + dbProp, err = testQueries.GetAllPropertiesForEntity(ctx, GetAllPropertiesForEntityParams{ + EntityID: ent.ID, + ProjectID: proj.ID, + ProviderID: prov.ID, + }) require.NoError(t, err) require.Len(t, dbProp, 2) require.Equal(t, "testvalue", propertyByKey(t, dbProp, "testkey")) require.Equal(t, "anothervalue", propertyByKey(t, dbProp, "anotherkey")) - keyVal, err := testQueries.GetProperty(context.Background(), GetPropertyParams{ - EntityID: ent.ID, - Key: "testkey", + keyVal, err := testQueries.GetProperty(ctx, GetPropertyParams{ + EntityID: ent.ID, + Key: "testkey", + ProjectID: proj.ID, + ProviderID: prov.ID, }) require.NoError(t, err) value, err := PropValueFromDbV1(keyVal.Value) require.NoError(t, err) require.Equal(t, "testvalue", value) - anotherKeyVal, err := testQueries.GetProperty(context.Background(), GetPropertyParams{ - EntityID: ent.ID, - Key: "anotherkey", + anotherKeyVal, err := testQueries.GetProperty(ctx, GetPropertyParams{ + EntityID: ent.ID, + Key: "anotherkey", + ProjectID: proj.ID, + ProviderID: prov.ID, }) require.NoError(t, err) anotherValue, err := PropValueFromDbV1(anotherKeyVal.Value) @@ -204,8 +216,7 @@ func Test_PropertyCrud(t *testing.T) { const testRepoName = "testorg/testrepo_getbyprops" const testArtifactName = "testorg/testartifact_getbyprops" - t.Log("Creating repository for GetTypedEntitiesByPropertyV1 test") - repo, err := testQueries.CreateEntity(context.Background(), CreateEntityParams{ + repo, err := testQueries.CreateEntity(ctx, CreateEntityParams{ EntityType: EntitiesRepository, Name: testRepoName, ProjectID: proj.ID, @@ -213,24 +224,22 @@ func Test_PropertyCrud(t *testing.T) { OriginatedFrom: uuid.NullUUID{}, }) require.NoError(t, err) - require.NotEmpty(t, repo) - _, err = testQueries.UpsertPropertyValueV1(context.Background(), UpsertPropertyValueV1Params{ + _, err = testQueries.UpsertPropertyValueV1(ctx, UpsertPropertyValueV1Params{ EntityID: repo.ID, Key: "sharedkey", Value: "sharedvalue", }) require.NoError(t, err) - _, err = testQueries.UpsertPropertyValueV1(context.Background(), UpsertPropertyValueV1Params{ + _, err = testQueries.UpsertPropertyValueV1(ctx, UpsertPropertyValueV1Params{ EntityID: repo.ID, Key: "repokey", Value: "repovalue", }) require.NoError(t, err) - t.Log("Creating artifact for GetTypedEntitiesByPropertyV1 test") - art, err := testQueries.CreateEntity(context.Background(), CreateEntityParams{ + art, err := testQueries.CreateEntity(ctx, CreateEntityParams{ EntityType: EntitiesArtifact, Name: testArtifactName, ProjectID: proj.ID, @@ -238,58 +247,56 @@ func Test_PropertyCrud(t *testing.T) { OriginatedFrom: uuid.NullUUID{}, }) require.NoError(t, err) - require.NotEmpty(t, art) - _, err = testQueries.UpsertPropertyValueV1(context.Background(), UpsertPropertyValueV1Params{ + _, err = testQueries.UpsertPropertyValueV1(ctx, UpsertPropertyValueV1Params{ EntityID: art.ID, Key: "sharedkey", Value: "sharedvalue", }) require.NoError(t, err) - t.Log("Get by shared key and repo should return the repository") getEnt, err := testQueries.GetTypedEntitiesByPropertyV1( - context.Background(), EntitiesRepository, "sharedkey", "sharedvalue", + ctx, EntitiesRepository, "sharedkey", "sharedvalue", GetTypedEntitiesOptions{ - ProjectID: proj.ID, + ProjectID: proj.ID, + ProviderID: prov.ID, }) require.NoError(t, err) require.Len(t, getEnt, 1) require.Equal(t, getEnt[0].ID, repo.ID) - t.Log("Get by shared key and artifact should return the artifact") getEnt, err = testQueries.GetTypedEntitiesByPropertyV1( - context.Background(), EntitiesArtifact, "sharedkey", "sharedvalue", + ctx, EntitiesArtifact, "sharedkey", "sharedvalue", GetTypedEntitiesOptions{ - ProjectID: proj.ID, + ProjectID: proj.ID, + ProviderID: prov.ID, }) require.NoError(t, err) require.Len(t, getEnt, 1) require.Equal(t, getEnt[0].ID, art.ID) - t.Log("Get by repo key and value should return the repository") getEnt, err = testQueries.GetTypedEntitiesByPropertyV1( - context.Background(), EntitiesRepository, "repokey", "repovalue", + ctx, EntitiesRepository, "repokey", "repovalue", GetTypedEntitiesOptions{ - ProjectID: proj.ID, + ProjectID: proj.ID, + ProviderID: prov.ID, }) require.NoError(t, err) require.Len(t, getEnt, 1) require.Equal(t, getEnt[0].ID, repo.ID) - t.Log("Get by repo key, value and provider should return the repository") getEnt, err = testQueries.GetTypedEntitiesByPropertyV1( - context.Background(), EntitiesRepository, "repokey", "repovalue", + ctx, EntitiesRepository, "repokey", "repovalue", GetTypedEntitiesOptions{ + ProjectID: proj.ID, ProviderID: prov.ID, }) require.NoError(t, err) require.Len(t, getEnt, 1) require.Equal(t, getEnt[0].ID, repo.ID) - t.Log("Get by repo key, value, project and provider should return the repository") getEnt, err = testQueries.GetTypedEntitiesByPropertyV1( - context.Background(), EntitiesRepository, "repokey", "repovalue", + ctx, EntitiesRepository, "repokey", "repovalue", GetTypedEntitiesOptions{ ProjectID: proj.ID, ProviderID: prov.ID, @@ -298,10 +305,10 @@ func Test_PropertyCrud(t *testing.T) { require.Len(t, getEnt, 1) require.Equal(t, getEnt[0].ID, repo.ID) - t.Log("Getting by key but with wrong provider should return nothing") getEnt, err = testQueries.GetTypedEntitiesByPropertyV1( - context.Background(), EntitiesRepository, "repokey", "repovalue", + ctx, EntitiesRepository, "repokey", "repovalue", GetTypedEntitiesOptions{ + ProjectID: proj.ID, ProviderID: uuid.New(), }) require.NoError(t, err) @@ -361,8 +368,9 @@ func Test_GetEntitiesByHierarchy(t *testing.T) { proj4 := createRandomProject(t, proj3.ID) prov := createRandomProvider(t, proj.ID) + ctx := context.Background() - entRepo, err := testQueries.CreateEntity(context.Background(), CreateEntityParams{ + entRepo, err := testQueries.CreateEntity(ctx, CreateEntityParams{ EntityType: EntitiesRepository, Name: rand.RandomName(seed), ProjectID: proj.ID, @@ -372,7 +380,7 @@ func Test_GetEntitiesByHierarchy(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, entRepo) - entPkg, err := testQueries.CreateEntity(context.Background(), CreateEntityParams{ + entPkg, err := testQueries.CreateEntity(ctx, CreateEntityParams{ EntityType: EntitiesRepository, Name: rand.RandomName(seed), ProjectID: proj2.ID, @@ -382,7 +390,7 @@ func Test_GetEntitiesByHierarchy(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, entPkg) - entPr, err := testQueries.CreateEntity(context.Background(), CreateEntityParams{ + entPr, err := testQueries.CreateEntity(ctx, CreateEntityParams{ EntityType: EntitiesRepository, Name: rand.RandomName(seed), ProjectID: proj3.ID, @@ -423,7 +431,7 @@ func Test_GetEntitiesByHierarchy(t *testing.T) { t.Run(scenario.name, func(t *testing.T) { t.Parallel() - ents, err := testQueries.GetEntitiesByProjectHierarchy(context.Background(), scenario.projects) + ents, err := testQueries.GetEntitiesByProjectHierarchy(ctx, scenario.projects) require.NoError(t, err) require.Len(t, ents, scenario.expectedNumEnts) }) diff --git a/internal/db/entity_execution_lock.sql.go b/internal/db/entity_execution_lock.sql.go index 5d4c2e71bc..64a45e84cf 100644 --- a/internal/db/entity_execution_lock.sql.go +++ b/internal/db/entity_execution_lock.sql.go @@ -7,6 +7,7 @@ package db import ( "context" + "time" "github.com/google/uuid" ) @@ -46,12 +47,24 @@ func (q *Queries) EnqueueFlush(ctx context.Context, arg EnqueueFlushParams) (Flu const flushCache = `-- name: FlushCache :one DELETE FROM flush_cache -WHERE entity_instance_id= $1 +WHERE flush_cache.entity_instance_id = $1 + AND flush_cache.project_id = $2 + AND EXISTS ( + SELECT 1 FROM entity_instances ei + WHERE ei.id = flush_cache.entity_instance_id + AND ei.provider_id = $3 + ) RETURNING id, entity, queued_at, project_id, entity_instance_id ` -func (q *Queries) FlushCache(ctx context.Context, entityInstanceID uuid.UUID) (FlushCache, error) { - row := q.db.QueryRowContext(ctx, flushCache, entityInstanceID) +type FlushCacheParams struct { + EntityInstanceID uuid.UUID `json:"entity_instance_id"` + ProjectID uuid.UUID `json:"project_id"` + ProviderID uuid.UUID `json:"provider_id"` +} + +func (q *Queries) FlushCache(ctx context.Context, arg FlushCacheParams) (FlushCache, error) { + row := q.db.QueryRowContext(ctx, flushCache, arg.EntityInstanceID, arg.ProjectID, arg.ProviderID) var i FlushCache err := row.Scan( &i.ID, @@ -64,24 +77,39 @@ func (q *Queries) FlushCache(ctx context.Context, entityInstanceID uuid.UUID) (F } const listFlushCache = `-- name: ListFlushCache :many -SELECT id, entity, queued_at, project_id, entity_instance_id FROM flush_cache +SELECT + fc.entity, + fc.project_id, + fc.entity_instance_id, + fc.queued_at, + ei.provider_id +FROM flush_cache fc +JOIN entity_instances ei ON fc.entity_instance_id = ei.id ` -func (q *Queries) ListFlushCache(ctx context.Context) ([]FlushCache, error) { +type ListFlushCacheRow struct { + Entity Entities `json:"entity"` + ProjectID uuid.UUID `json:"project_id"` + EntityInstanceID uuid.UUID `json:"entity_instance_id"` + QueuedAt time.Time `json:"queued_at"` + ProviderID uuid.UUID `json:"provider_id"` +} + +func (q *Queries) ListFlushCache(ctx context.Context) ([]ListFlushCacheRow, error) { rows, err := q.db.QueryContext(ctx, listFlushCache) if err != nil { return nil, err } defer rows.Close() - items := []FlushCache{} + items := []ListFlushCacheRow{} for rows.Next() { - var i FlushCache + var i ListFlushCacheRow if err := rows.Scan( - &i.ID, &i.Entity, - &i.QueuedAt, &i.ProjectID, &i.EntityInstanceID, + &i.QueuedAt, + &i.ProviderID, ); err != nil { return nil, err } diff --git a/internal/db/eval_history.sql.go b/internal/db/eval_history.sql.go index 687eb4ef8c..a5ae80a325 100644 --- a/internal/db/eval_history.sql.go +++ b/internal/db/eval_history.sql.go @@ -271,6 +271,7 @@ SELECT s.id::uuid AS evaluation_id, -- entity id ere.entity_instance_id as entity_id, j.id as project_id, + ei.provider_id, -- rule type, name, and profile rt.name AS rule_type, ri.name AS rule_name, @@ -362,6 +363,7 @@ type ListEvaluationHistoryRow struct { EntityType Entities `json:"entity_type"` EntityID uuid.UUID `json:"entity_id"` ProjectID uuid.UUID `json:"project_id"` + ProviderID uuid.UUID `json:"provider_id"` RuleType string `json:"rule_type"` RuleName string `json:"rule_name"` RuleSeverity Severity `json:"rule_severity"` @@ -413,6 +415,7 @@ func (q *Queries) ListEvaluationHistory(ctx context.Context, arg ListEvaluationH &i.EntityType, &i.EntityID, &i.ProjectID, + &i.ProviderID, &i.RuleType, &i.RuleName, &i.RuleSeverity, diff --git a/internal/db/profile_selector_scan.go b/internal/db/profile_selector_scan.go index cb35c3ca0f..fad9d151de 100644 --- a/internal/db/profile_selector_scan.go +++ b/internal/db/profile_selector_scan.go @@ -15,7 +15,7 @@ func (s *ProfileSelector) Scan(value interface{}) error { return nil } - // Convert the value to a string + // Convert the value to string bytes, ok := value.([]byte) if !ok { return fmt.Errorf("failed to scan SelectorInfo: %v", value) diff --git a/internal/db/profile_status.sql.go b/internal/db/profile_status.sql.go index ce5368dfcc..0960e5688e 100644 --- a/internal/db/profile_status.sql.go +++ b/internal/db/profile_status.sql.go @@ -252,6 +252,7 @@ SELECT ere.entity_instance_id as entity_id, ei.name as entity_name, ei.project_id as project_id, + ei.provider_id, rt.release_phase as rule_type_release_phase, eo.output AS eval_output FROM latest_evaluation_statuses les @@ -304,6 +305,7 @@ type ListRuleEvaluationsByProfileIdRow struct { EntityID uuid.UUID `json:"entity_id"` EntityName string `json:"entity_name"` ProjectID uuid.UUID `json:"project_id"` + ProviderID uuid.UUID `json:"provider_id"` RuleTypeReleasePhase ReleaseStatus `json:"rule_type_release_phase"` EvalOutput pqtype.NullRawMessage `json:"eval_output"` } @@ -348,6 +350,7 @@ func (q *Queries) ListRuleEvaluationsByProfileId(ctx context.Context, arg ListRu &i.EntityID, &i.EntityName, &i.ProjectID, + &i.ProviderID, &i.RuleTypeReleasePhase, &i.EvalOutput, ); err != nil { diff --git a/internal/db/profiles_test.go b/internal/db/profiles_test.go index 5f3aabaad4..0fa041ca9a 100644 --- a/internal/db/profiles_test.go +++ b/internal/db/profiles_test.go @@ -1,3 +1,6 @@ +// SPDX-FileCopyrightText: Copyright 2023 The Minder Authors +// SPDX-License-Identifier: Apache-2.0 + package db import ( @@ -3418,8 +3421,9 @@ func TestCreateProfileStatusStoredDeleteProcedure(t *testing.T) { expectedStatusAfterSetup: EvalStatusTypesFailure, ruleStatusDeleteFn: func(delRepo *EntityRepository) { err := testQueries.DeleteEntity(context.Background(), DeleteEntityParams{ - ID: delRepo.ID, - ProjectID: delRepo.ProjectID, + ID: delRepo.ID, + ProjectID: delRepo.ProjectID, + ProviderID: delRepo.ProviderID, }) require.NoError(t, err) }, @@ -3471,8 +3475,9 @@ func TestCreateProfileStatusStoredDeleteProcedure(t *testing.T) { expectedStatusAfterSetup: EvalStatusTypesError, ruleStatusDeleteFn: func(delRepo *EntityRepository) { err := testQueries.DeleteEntity(context.Background(), DeleteEntityParams{ - ID: delRepo.ID, - ProjectID: delRepo.ProjectID, + ID: delRepo.ID, + ProjectID: delRepo.ProjectID, + ProviderID: delRepo.ProviderID, }) require.NoError(t, err) }, @@ -3524,8 +3529,9 @@ func TestCreateProfileStatusStoredDeleteProcedure(t *testing.T) { expectedStatusAfterSetup: EvalStatusTypesError, ruleStatusDeleteFn: func(delRepo *EntityRepository) { err := testQueries.DeleteEntity(context.Background(), DeleteEntityParams{ - ID: delRepo.ID, - ProjectID: delRepo.ProjectID, + ID: delRepo.ID, + ProjectID: delRepo.ProjectID, + ProviderID: delRepo.ProviderID, }) require.NoError(t, err) }, @@ -3558,8 +3564,9 @@ func TestCreateProfileStatusStoredDeleteProcedure(t *testing.T) { expectedStatusAfterSetup: EvalStatusTypesFailure, ruleStatusDeleteFn: func(delRepo *EntityRepository) { err := testQueries.DeleteEntity(context.Background(), DeleteEntityParams{ - ID: delRepo.ID, - ProjectID: delRepo.ProjectID, + ID: delRepo.ID, + ProjectID: delRepo.ProjectID, + ProviderID: delRepo.ProviderID, }) require.NoError(t, err) }, diff --git a/internal/db/projects_test.go b/internal/db/projects_test.go index a1183fd9d7..1c42fc98c1 100644 --- a/internal/db/projects_test.go +++ b/internal/db/projects_test.go @@ -137,6 +137,7 @@ func TestCreateDirectoryWithParentThatDoesntExist(t *testing.T) { _, err := testQueries.CreateProject(context.Background(), CreateProjectParams{ Name: t.Name(), ParentID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + Metadata: json.RawMessage("{}"), }) assert.Error(t, err, "should have errored") diff --git a/internal/db/querier.go b/internal/db/querier.go index 5c6286796f..7e25307b55 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -20,7 +20,7 @@ type Querier interface { // AddRuleTypeDataSourceReference(ctx context.Context, arg AddRuleTypeDataSourceReferenceParams) (RuleTypeDataSource, error) BulkGetProfilesByID(ctx context.Context, profileIds []uuid.UUID) ([]BulkGetProfilesByIDRow, error) - // CountEntitiesByType counts all entities of a given type (across all projects/providers). + // CountEntitiesByType counts all entities of a given type (Global admin metric). CountEntitiesByType(ctx context.Context, entityType Entities) (int64, error) // CountEntitiesByTypeAndProject counts entities of a given type for a specific project. CountEntitiesByTypeAndProject(ctx context.Context, arg CountEntitiesByTypeAndProjectParams) (int64, error) @@ -54,13 +54,14 @@ type Querier interface { // Subscriptions -- CreateSubscription(ctx context.Context, arg CreateSubscriptionParams) (Subscription, error) CreateUser(ctx context.Context, identitySubject string) (User, error) - DeleteAllPropertiesForEntity(ctx context.Context, entityID uuid.UUID) error + // DeleteAllPropertiesForEntity deletes all properties for an entity securely. + DeleteAllPropertiesForEntity(ctx context.Context, arg DeleteAllPropertiesForEntityParams) error DeleteDataSource(ctx context.Context, arg DeleteDataSourceParams) (DataSource, error) DeleteDataSourceFunction(ctx context.Context, arg DeleteDataSourceFunctionParams) (DataSourcesFunction, error) // DeleteDataSourceFunctions deletes all functions associated with a given datasource // in a specific project. DeleteDataSourceFunctions(ctx context.Context, arg DeleteDataSourceFunctionsParams) ([]DataSourcesFunction, error) - // DeleteEntity removes an entity from the entity_instances table for a project. + // DeleteEntity removes an entity from the entity_instances table securely. DeleteEntity(ctx context.Context, arg DeleteEntityParams) error DeleteEvaluationHistoryByIDs(ctx context.Context, evaluationids []uuid.UUID) (int64, error) DeleteEvaluationOutputsByEvaluationIDs(ctx context.Context, evaluationids []uuid.UUID) (int64, error) @@ -74,6 +75,7 @@ type Querier interface { DeleteProfile(ctx context.Context, arg DeleteProfileParams) error DeleteProfileForEntity(ctx context.Context, arg DeleteProfileForEntityParams) error DeleteProject(ctx context.Context, id uuid.UUID) ([]DeleteProjectRow, error) + // DeleteProperty deletes a property, using USING to ensure the caller owns the parent entity. DeleteProperty(ctx context.Context, arg DeletePropertyParams) error DeleteProvider(ctx context.Context, arg DeleteProviderParams) error DeleteRuleInstanceOfProfileInProject(ctx context.Context, arg DeleteRuleInstanceOfProfileInProjectParams) error @@ -84,18 +86,19 @@ type Querier interface { DeleteSessionStateByProjectID(ctx context.Context, arg DeleteSessionStateByProjectIDParams) error DeleteUser(ctx context.Context, id int32) error EnqueueFlush(ctx context.Context, arg EnqueueFlushParams) (FlushCache, error) - // EntityExistsAfterID checks if any entity of a given type exists after a cursor ID. + // EntityExistsAfterID checks if any entity exists after a cursor ID securely. EntityExistsAfterID(ctx context.Context, arg EntityExistsAfterIDParams) (bool, error) // FindProviders allows us to take a trait and filter // providers by it. It also optionally takes a name, in case we want to // filter by name as well. FindProviders(ctx context.Context, arg FindProvidersParams) ([]Provider, error) - FlushCache(ctx context.Context, entityInstanceID uuid.UUID) (FlushCache, error) + FlushCache(ctx context.Context, arg FlushCacheParams) (FlushCache, error) GetAccessTokenByEnrollmentNonce(ctx context.Context, arg GetAccessTokenByEnrollmentNonceParams) (ProviderAccessToken, error) GetAccessTokenByProjectID(ctx context.Context, arg GetAccessTokenByProjectIDParams) (ProviderAccessToken, error) GetAccessTokenByProvider(ctx context.Context, provider string) ([]ProviderAccessToken, error) GetAccessTokenSinceDate(ctx context.Context, arg GetAccessTokenSinceDateParams) (ProviderAccessToken, error) - GetAllPropertiesForEntity(ctx context.Context, entityID uuid.UUID) ([]Property, error) + // GetAllPropertiesForEntity retrieves all properties for one entity, strictly bounded. + GetAllPropertiesForEntity(ctx context.Context, arg GetAllPropertiesForEntityParams) ([]Property, error) GetBundle(ctx context.Context, arg GetBundleParams) (Bundle, error) GetChildrenProjects(ctx context.Context, id uuid.UUID) ([]GetChildrenProjectsRow, error) // GetDataSource retrieves a datasource by its id and a project hierarchy. @@ -111,16 +114,16 @@ type Querier interface { GetDataSourceByName(ctx context.Context, arg GetDataSourceByNameParams) (DataSource, error) // GetEntitiesByProjectHierarchy retrieves all entities for a project or hierarchy of projects. GetEntitiesByProjectHierarchy(ctx context.Context, projects []uuid.UUID) ([]EntityInstance, error) - // GetEntitiesByProvider retrieves all entities of a given provider. + // GetEntitiesByProvider retrieves all entities of a given provider scoped by project hierarchy. // this is how one would get all repositories, artifacts, etc. for a given provider. - GetEntitiesByProvider(ctx context.Context, providerID uuid.UUID) ([]EntityInstance, error) - // GetEntitiesByType retrieves all entities of a given type for a project or hierarchy of projects. + GetEntitiesByProvider(ctx context.Context, arg GetEntitiesByProviderParams) ([]EntityInstance, error) + // GetEntitiesByType retrieves all entities of a given type for a project hierarchy. // this is how one would get all repositories, artifacts, etc. GetEntitiesByType(ctx context.Context, arg GetEntitiesByTypeParams) ([]EntityInstance, error) GetEntitlementFeaturesByProjectID(ctx context.Context, projectID uuid.UUID) ([]string, error) // GetEntityByID retrieves an entity by its ID for a project or hierarchy of projects. - GetEntityByID(ctx context.Context, id uuid.UUID) (EntityInstance, error) - // GetEntityByName retrieves an entity by its name for a project or hierarchy of projects. + GetEntityByID(ctx context.Context, arg GetEntityByIDParams) (EntityInstance, error) + // GetEntityByName retrieves an entity by its name securely. GetEntityByName(ctx context.Context, arg GetEntityByNameParams) (EntityInstance, error) GetEvaluationHistory(ctx context.Context, arg GetEvaluationHistoryParams) (GetEvaluationHistoryRow, error) GetEvaluationOutput(ctx context.Context, id uuid.UUID) (EvaluationOutput, error) @@ -163,6 +166,9 @@ type Querier interface { GetProjectByID(ctx context.Context, id uuid.UUID) (Project, error) GetProjectByName(ctx context.Context, name string) (Project, error) GetProjectIDBySessionState(ctx context.Context, sessionState string) (GetProjectIDBySessionStateRow, error) + // GetPropertiesForEntities retrieves properties for multiple entities in bulk + GetPropertiesForEntities(ctx context.Context, arg GetPropertiesForEntitiesParams) ([]Property, error) + // GetProperty retrieves a single property, using a JOIN to ensure the caller owns the parent entity GetProperty(ctx context.Context, arg GetPropertyParams) (Property, error) GetProviderByID(ctx context.Context, id uuid.UUID) (Provider, error) GetProviderByIDAndProject(ctx context.Context, arg GetProviderByIDAndProjectParams) (Provider, error) @@ -186,6 +192,7 @@ type Querier interface { GetSelectorByID(ctx context.Context, id uuid.UUID) (ProfileSelector, error) GetSelectorsByProfileID(ctx context.Context, profileID uuid.UUID) ([]ProfileSelector, error) GetSubscriptionByProjectBundle(ctx context.Context, arg GetSubscriptionByProjectBundleParams) (Subscription, error) + // GetTypedEntitiesByProperty retrieves entities matching a specific property JSONB query securely. GetTypedEntitiesByProperty(ctx context.Context, arg GetTypedEntitiesByPropertyParams) ([]EntityInstance, error) GetUnclaimedInstallationsByUser(ctx context.Context, ghID sql.NullString) ([]ProviderGithubAppInstallation, error) GetUserByID(ctx context.Context, id int32) (User, error) @@ -204,12 +211,12 @@ type Querier interface { // Note that to get a datasource for a given project, one can simply // pass one project id in the project_id array. ListDataSources(ctx context.Context, projects []uuid.UUID) ([]DataSource, error) - // ListEntitiesAfterID retrieves entities of a given type after a cursor ID, for pagination. + // ListEntitiesAfterID retrieves entities for pagination securely. // This is used for cursor-based iteration over all entities (e.g., in the reminder service). ListEntitiesAfterID(ctx context.Context, arg ListEntitiesAfterIDParams) ([]EntityInstance, error) ListEvaluationHistory(ctx context.Context, arg ListEvaluationHistoryParams) ([]ListEvaluationHistoryRow, error) ListEvaluationHistoryStaleRecords(ctx context.Context, arg ListEvaluationHistoryStaleRecordsParams) ([]ListEvaluationHistoryStaleRecordsRow, error) - ListFlushCache(ctx context.Context) ([]FlushCache, error) + ListFlushCache(ctx context.Context) ([]ListFlushCacheRow, error) // ListInvitationsForProject collects the information visible to project // administrators after an invitation has been issued. In particular, it // *does not* report the invitation code, which is a secret intended for @@ -286,6 +293,8 @@ type Querier interface { UpsertInstallationID(ctx context.Context, arg UpsertInstallationIDParams) (ProviderGithubAppInstallation, error) UpsertLatestEvaluationStatus(ctx context.Context, arg UpsertLatestEvaluationStatusParams) error UpsertProfileForEntity(ctx context.Context, arg UpsertProfileForEntityParams) (EntityProfile, error) + // UpsertProperty upserts a property. + // NOTE: Ownership MUST be verified in Go (e.g. via GetEntityByID) before executing this statement. UpsertProperty(ctx context.Context, arg UpsertPropertyParams) (Property, error) // SPDX-FileCopyrightText: Copyright 2024 The Minder Authors // SPDX-License-Identifier: Apache-2.0 diff --git a/internal/eea/eea.go b/internal/eea/eea.go index 6f2d0fa55b..958d09d2f9 100644 --- a/internal/eea/eea.go +++ b/internal/eea/eea.go @@ -107,7 +107,11 @@ func (e *EEA) aggregate(msg *message.Message) (*message.Message, error) { qtx := e.querier.GetQuerierWithTransaction(tx) // We'll only attempt to lock if the entity exists. - _, err = qtx.GetEntityByID(ctx, entityID) + _, err = qtx.GetEntityByID(ctx, db.GetEntityByIDParams{ + ID: entityID, + ProjectID: projectID, + ProviderID: inf.ProviderID, + }) if err != nil { // explicit rollback if entity had an issue. _ = e.querier.Rollback(tx) @@ -187,7 +191,11 @@ func (e *EEA) FlushMessageHandler(msg *message.Message) error { logger.Debug().Msg("flushing event") - _, err = e.querier.FlushCache(ctx, eID) + _, err = e.querier.FlushCache(ctx, db.FlushCacheParams{ + EntityInstanceID: eID, + ProjectID: inf.ProjectID, + ProviderID: inf.ProviderID, + }) // Nothing to do here. If we can't flush the cache, it means // that the event has already been executed. if err != nil && errors.Is(err, sql.ErrNoRows) { @@ -218,7 +226,7 @@ func (e *EEA) FlushAll(ctx context.Context) error { for _, cache := range caches { eiw, err := e.buildEntityWrapper(ctx, cache.Entity, - cache.ProjectID, cache.EntityInstanceID) + cache.ProjectID, cache.EntityInstanceID, cache.ProviderID) if err != nil { if errors.Is(err, sql.ErrNoRows) || errors.Is(err, service.ErrEntityNotFound) { continue @@ -246,14 +254,15 @@ func (e *EEA) buildEntityWrapper( entity db.Entities, projID uuid.UUID, entityID uuid.UUID, + providerID uuid.UUID, ) (*entities.EntityInfoWrapper, error) { switch entity { case db.EntitiesRepository: - return e.buildRepositoryInfoWrapper(ctx, entityID, projID) + return e.buildRepositoryInfoWrapper(ctx, entityID, projID, providerID) case db.EntitiesArtifact: - return e.buildArtifactInfoWrapper(ctx, entityID, projID) + return e.buildArtifactInfoWrapper(ctx, entityID, projID, providerID) case db.EntitiesPullRequest: - return e.buildPullRequestInfoWrapper(ctx, entityID, projID) + return e.buildPullRequestInfoWrapper(ctx, entityID, projID, providerID) case db.EntitiesBuildEnvironment, db.EntitiesRelease, db.EntitiesPipelineRun, db.EntitiesTaskRun, db.EntitiesBuild: return nil, fmt.Errorf("entity type %q not yet supported", entity) @@ -261,13 +270,13 @@ func (e *EEA) buildEntityWrapper( return nil, fmt.Errorf("unknown entity type: %q", entity) } } - func (e *EEA) buildRepositoryInfoWrapper( ctx context.Context, repoID uuid.UUID, projID uuid.UUID, + providerID uuid.UUID, ) (*entities.EntityInfoWrapper, error) { - ent, err := e.entityFetcher.EntityWithPropertiesByID(ctx, repoID, nil) + ent, err := e.entityFetcher.EntityWithPropertiesByID(ctx, repoID, projID, providerID, nil) if err != nil { return nil, fmt.Errorf("error fetching entity: %w", err) } @@ -297,8 +306,9 @@ func (e *EEA) buildArtifactInfoWrapper( ctx context.Context, artID uuid.UUID, projID uuid.UUID, + providerID uuid.UUID, ) (*entities.EntityInfoWrapper, error) { - ent, err := e.entityFetcher.EntityWithPropertiesByID(ctx, artID, nil) + ent, err := e.entityFetcher.EntityWithPropertiesByID(ctx, artID, projID, providerID, nil) if err != nil { return nil, fmt.Errorf("error fetching entity: %w", err) } @@ -329,8 +339,9 @@ func (e *EEA) buildPullRequestInfoWrapper( ctx context.Context, prID uuid.UUID, projID uuid.UUID, + providerID uuid.UUID, ) (*entities.EntityInfoWrapper, error) { - ent, err := e.entityFetcher.EntityWithPropertiesByID(ctx, prID, nil) + ent, err := e.entityFetcher.EntityWithPropertiesByID(ctx, prID, projID, providerID, nil) if err != nil { return nil, fmt.Errorf("error fetching entity: %w", err) } diff --git a/internal/eea/eea_test.go b/internal/eea/eea_test.go index a9c3b17166..fef49fcf58 100644 --- a/internal/eea/eea_test.go +++ b/internal/eea/eea_test.go @@ -39,10 +39,6 @@ const ( providerName = "test-provider" ) -var ( - providerID = uuid.New() -) - func TestAggregator(t *testing.T) { t.Parallel() @@ -55,7 +51,7 @@ func TestAggregator(t *testing.T) { var concurrentEvents int64 = 100 - projectID, repoID := createNeededEntities(ctx, t, testQueries) + projectID, dbProviderID, repoID := createNeededEntities(ctx, t, testQueries) evt, err := eventer.New(ctx, nil, &serverconfig.EventConfig{ Driver: "go-channel", @@ -101,7 +97,7 @@ func TestAggregator(t *testing.T) { WithRepository(&minderv1.Repository{}). WithID(repoID). WithProjectID(projectID). - WithProviderID(providerID) + WithProviderID(dbProviderID) msg, err := inf.BuildMessage() require.NoError(t, err, "expected no error when building message") @@ -147,7 +143,7 @@ func TestAggregator(t *testing.T) { assert.Equal(t, int32(1), flushedMessages.count.Load(), "expected only one message to be published") } -func createNeededEntities(ctx context.Context, t *testing.T, testQueries db.Store) (projID uuid.UUID, repoID uuid.UUID) { +func createNeededEntities(ctx context.Context, t *testing.T, testQueries db.Store) (projID uuid.UUID, provID uuid.UUID, repoID uuid.UUID) { t.Helper() // setup project @@ -177,7 +173,7 @@ func createNeededEntities(ctx context.Context, t *testing.T, testQueries db.Stor }) require.NoError(t, err, "expected no error when creating repo") - return proj.ID, repo.ID + return proj.ID, prov.ID, repo.ID } func TestFlushAll(t *testing.T) { @@ -198,13 +194,13 @@ func TestFlushAll(t *testing.T) { name: "flushes one repo", mockDBSetup: func(ctx context.Context, mockStore *mockdb.MockStore) { mockStore.EXPECT().ListFlushCache(ctx). - Return([]db.FlushCache{ + Return([]db.ListFlushCacheRow{ { - ID: uuid.New(), Entity: db.EntitiesRepository, QueuedAt: time.Now(), ProjectID: projectID, EntityInstanceID: repoID, + ProviderID: providerID, }, }, nil) @@ -212,7 +208,7 @@ func TestFlushAll(t *testing.T) { mockStore.EXPECT().FlushCache(ctx, gomock.Any()).Times(1) }, mockPropSvcSetup: func(mockPropSvc *propsvcmock.MockPropertiesService) { - mockPropSvc.EXPECT().EntityWithPropertiesByID(gomock.Any(), gomock.Eq(repoID), gomock.Nil()). + mockPropSvc.EXPECT().EntityWithPropertiesByID(gomock.Any(), gomock.Eq(repoID), gomock.Eq(projectID), gomock.Eq(providerID), gomock.Nil()). Return(&models.EntityWithProperties{ Entity: models.EntityInstance{ ID: repoID, @@ -231,13 +227,13 @@ func TestFlushAll(t *testing.T) { name: "flushes one artifact with repo", mockDBSetup: func(ctx context.Context, mockStore *mockdb.MockStore) { mockStore.EXPECT().ListFlushCache(ctx). - Return([]db.FlushCache{ + Return([]db.ListFlushCacheRow{ { - ID: uuid.New(), Entity: db.EntitiesArtifact, QueuedAt: time.Now(), ProjectID: projectID, EntityInstanceID: artID, + ProviderID: providerID, }, }, nil) @@ -245,7 +241,7 @@ func TestFlushAll(t *testing.T) { mockStore.EXPECT().FlushCache(ctx, gomock.Any()).Times(1) }, mockPropSvcSetup: func(mockPropSvc *propsvcmock.MockPropertiesService) { - mockPropSvc.EXPECT().EntityWithPropertiesByID(gomock.Any(), gomock.Eq(artID), gomock.Nil()). + mockPropSvc.EXPECT().EntityWithPropertiesByID(gomock.Any(), gomock.Eq(artID), gomock.Eq(projectID), gomock.Eq(providerID), gomock.Nil()). Return(&models.EntityWithProperties{ Entity: models.EntityInstance{ ID: artID, @@ -265,13 +261,13 @@ func TestFlushAll(t *testing.T) { name: "flushes one artifact with no repo", mockDBSetup: func(ctx context.Context, mockStore *mockdb.MockStore) { mockStore.EXPECT().ListFlushCache(ctx). - Return([]db.FlushCache{ + Return([]db.ListFlushCacheRow{ { - ID: uuid.New(), Entity: db.EntitiesArtifact, QueuedAt: time.Now(), ProjectID: projectID, EntityInstanceID: artID, + ProviderID: providerID, }, }, nil) @@ -279,7 +275,7 @@ func TestFlushAll(t *testing.T) { mockStore.EXPECT().FlushCache(ctx, gomock.Any()).Times(1) }, mockPropSvcSetup: func(mockPropSvc *propsvcmock.MockPropertiesService) { - mockPropSvc.EXPECT().EntityWithPropertiesByID(gomock.Any(), gomock.Eq(artID), gomock.Nil()). + mockPropSvc.EXPECT().EntityWithPropertiesByID(gomock.Any(), gomock.Eq(artID), gomock.Eq(projectID), gomock.Eq(providerID), gomock.Nil()). Return(&models.EntityWithProperties{ Entity: models.EntityInstance{ ID: artID, @@ -298,13 +294,13 @@ func TestFlushAll(t *testing.T) { name: "flushes one PR", mockDBSetup: func(ctx context.Context, mockStore *mockdb.MockStore) { mockStore.EXPECT().ListFlushCache(ctx). - Return([]db.FlushCache{ + Return([]db.ListFlushCacheRow{ { - ID: uuid.New(), Entity: db.EntitiesPullRequest, ProjectID: projectID, EntityInstanceID: prID, QueuedAt: time.Now(), + ProviderID: providerID, }, }, nil) @@ -312,7 +308,7 @@ func TestFlushAll(t *testing.T) { mockStore.EXPECT().FlushCache(ctx, gomock.Any()).Times(1) }, mockPropSvcSetup: func(mockPropSvc *propsvcmock.MockPropertiesService) { - mockPropSvc.EXPECT().EntityWithPropertiesByID(gomock.Any(), gomock.Eq(prID), gomock.Nil()). + mockPropSvc.EXPECT().EntityWithPropertiesByID(gomock.Any(), gomock.Eq(prID), gomock.Eq(projectID), gomock.Eq(providerID), gomock.Nil()). Return(&models.EntityWithProperties{ Entity: models.EntityInstance{ ID: prID, @@ -507,20 +503,21 @@ func TestFlushAllListFlushListsARepoThatGetsDeletedLater(t *testing.T) { repoID := uuid.New() projID := uuid.New() + provID := uuid.New() // initial list flush mockStore.EXPECT().ListFlushCache(ctx). - Return([]db.FlushCache{ + Return([]db.ListFlushCacheRow{ { - ID: uuid.New(), Entity: db.EntitiesRepository, ProjectID: projID, EntityInstanceID: repoID, QueuedAt: time.Now(), + ProviderID: provID, }, }, nil) - propsvc.EXPECT().EntityWithPropertiesByID(gomock.Any(), gomock.Eq(repoID), gomock.Nil()). + propsvc.EXPECT().EntityWithPropertiesByID(gomock.Any(), gomock.Eq(repoID), gomock.Eq(projID), gomock.Eq(provID), gomock.Nil()). Return(nil, psvc.ErrEntityNotFound) t.Log("Flushing all") diff --git a/internal/engine/executor.go b/internal/engine/executor.go index 94caac4591..ea1b5cde9e 100644 --- a/internal/engine/executor.go +++ b/internal/engine/executor.go @@ -240,7 +240,7 @@ func (e *executor) profileEvalStatus( } // get the entity with properties by the entity UUID - ewp, err := e.propService.EntityWithPropertiesByID(ctx, entityID, + ewp, err := e.propService.EntityWithPropertiesByID(ctx, entityID, eiw.ProjectID, eiw.ProviderID, service.CallBuilder().WithStoreOrTransaction(e.querier)) if err != nil { return fmt.Errorf("error getting entity with properties: %w", err) diff --git a/internal/engine/executor_test.go b/internal/engine/executor_test.go index ee671ea848..7c907492df 100644 --- a/internal/engine/executor_test.go +++ b/internal/engine/executor_test.go @@ -317,7 +317,7 @@ default allow = true`, mockPropSvc := mockprops.NewMockPropertiesService(ctrl) mockPropSvc.EXPECT(). - EntityWithPropertiesByID(gomock.Any(), repositoryID, gomock.Any()). + EntityWithPropertiesByID(gomock.Any(), repositoryID, projectID, providerID, gomock.Any()). Return(&models.EntityWithProperties{ Entity: models.EntityInstance{ ID: repositoryID, diff --git a/internal/entities/handlers/handler_test.go b/internal/entities/handlers/handler_test.go index f80ffd0492..6c5f2f232d 100644 --- a/internal/entities/handlers/handler_test.go +++ b/internal/entities/handlers/handler_test.go @@ -210,7 +210,9 @@ func TestRefreshEntityAndDoHandler_HandleRefreshEntityAndEval(t *testing.T) { handlerBuilderFn: refreshByIDHandlerBuilder, messageBuilder: func() *message.HandleEntityAndDoMessage { return message.NewEntityRefreshAndDoMessage(). - WithEntityID(repoID) + WithEntityID(repoID). + WithProjectID(projectID). + WithProviderID(providerID) }, setupPropSvcMocks: func() fixtures.MockPropertyServiceBuilder { ewp := buildEwp(t, repoEwp, repoPropMap) @@ -235,7 +237,9 @@ func TestRefreshEntityAndDoHandler_HandleRefreshEntityAndEval(t *testing.T) { handlerBuilderFn: refreshByIDHandlerBuilder, messageBuilder: func() *message.HandleEntityAndDoMessage { return message.NewEntityRefreshAndDoMessage(). - WithEntityID(uuid.Nil) + WithEntityID(uuid.Nil). + WithProjectID(projectID). + WithProviderID(providerID) }, setupPropSvcMocks: func() fixtures.MockPropertyServiceBuilder { return fixtures.NewMockPropertiesService() @@ -255,7 +259,9 @@ func TestRefreshEntityAndDoHandler_HandleRefreshEntityAndEval(t *testing.T) { return message.NewEntityRefreshAndDoMessage(). WithEntity(minderv1.Entity_ENTITY_REPOSITORIES, getByProps). - WithProviderImplementsHint("github") + WithProviderImplementsHint("github"). + WithProjectID(projectID). + WithProviderID(providerID) }, setupPropSvcMocks: func() fixtures.MockPropertyServiceBuilder { ewp := buildEwp(t, repoEwp, repoPropMap) @@ -290,7 +296,9 @@ func TestRefreshEntityAndDoHandler_HandleRefreshEntityAndEval(t *testing.T) { return message.NewEntityRefreshAndDoMessage(). WithEntity(minderv1.Entity_ENTITY_REPOSITORIES, getByProps). WithMatchProps(matchProps). - WithProviderImplementsHint("github") + WithProviderImplementsHint("github"). + WithProjectID(projectID). + WithProviderID(providerID) }, setupPropSvcMocks: func() fixtures.MockPropertyServiceBuilder { ewp := buildEwp(t, repoEwp, repoPropMap) @@ -325,7 +333,9 @@ func TestRefreshEntityAndDoHandler_HandleRefreshEntityAndEval(t *testing.T) { return message.NewEntityRefreshAndDoMessage(). WithEntity(minderv1.Entity_ENTITY_REPOSITORIES, getByProps). WithMatchProps(matchProps). - WithProviderImplementsHint("github") + WithProviderImplementsHint("github"). + WithProjectID(projectID). + WithProviderID(providerID) }, setupPropSvcMocks: func() fixtures.MockPropertyServiceBuilder { ewp := buildEwp(t, repoEwp, repoPropMap) @@ -350,7 +360,9 @@ func TestRefreshEntityAndDoHandler_HandleRefreshEntityAndEval(t *testing.T) { return message.NewEntityRefreshAndDoMessage(). WithEntity(minderv1.Entity_ENTITY_REPOSITORIES, getByProps). - WithProviderImplementsHint("github") + WithProviderImplementsHint("github"). + WithProjectID(projectID). + WithProviderID(providerID) }, setupPropSvcMocks: func() fixtures.MockPropertyServiceBuilder { privateRepoMap := maps.Clone(repoPropMap) @@ -384,7 +396,9 @@ func TestRefreshEntityAndDoHandler_HandleRefreshEntityAndEval(t *testing.T) { return message.NewEntityRefreshAndDoMessage(). WithEntity(minderv1.Entity_ENTITY_REPOSITORIES, getByProps). - WithProviderImplementsHint("github") + WithProviderImplementsHint("github"). + WithProjectID(projectID). + WithProviderID(providerID) }, setupPropSvcMocks: func() fixtures.MockPropertyServiceBuilder { privateRepoMap := maps.Clone(repoPropMap) @@ -413,7 +427,9 @@ func TestRefreshEntityAndDoHandler_HandleRefreshEntityAndEval(t *testing.T) { return message.NewEntityRefreshAndDoMessage(). WithEntity(minderv1.Entity_ENTITY_REPOSITORIES, getByProps). - WithProviderImplementsHint("github") + WithProviderImplementsHint("github"). + WithProjectID(projectID). + WithProviderID(providerID) }, setupPropSvcMocks: func() fixtures.MockPropertyServiceBuilder { archivedRepoMap := maps.Clone(repoPropMap) @@ -441,7 +457,9 @@ func TestRefreshEntityAndDoHandler_HandleRefreshEntityAndEval(t *testing.T) { return message.NewEntityRefreshAndDoMessage(). WithEntity(minderv1.Entity_ENTITY_REPOSITORIES, getByProps). - WithProviderImplementsHint("github") + WithProviderImplementsHint("github"). + WithProjectID(projectID). + WithProviderID(providerID) }, setupPropSvcMocks: func() fixtures.MockPropertyServiceBuilder { return fixtures.NewMockPropertiesService( @@ -463,7 +481,9 @@ func TestRefreshEntityAndDoHandler_HandleRefreshEntityAndEval(t *testing.T) { return message.NewEntityRefreshAndDoMessage(). WithEntity(minderv1.Entity_ENTITY_REPOSITORIES, getByProps). - WithProviderImplementsHint("github") + WithProviderImplementsHint("github"). + WithProjectID(projectID). + WithProviderID(providerID) }, setupPropSvcMocks: func() fixtures.MockPropertyServiceBuilder { return fixtures.NewMockPropertiesService( @@ -486,7 +506,9 @@ func TestRefreshEntityAndDoHandler_HandleRefreshEntityAndEval(t *testing.T) { return message.NewEntityRefreshAndDoMessage(). WithEntity(minderv1.Entity_ENTITY_REPOSITORIES, getByProps). - WithProviderImplementsHint("github") + WithProviderImplementsHint("github"). + WithProjectID(projectID). + WithProviderID(providerID) }, setupPropSvcMocks: func() fixtures.MockPropertyServiceBuilder { return fixtures.NewMockPropertiesService( @@ -594,7 +616,9 @@ func TestRefreshEntityAndDoHandler_HandleRefreshEntityAndEval(t *testing.T) { return message.NewEntityRefreshAndDoMessage(). WithEntity(minderv1.Entity_ENTITY_PULL_REQUESTS, prProps). WithOriginator(minderv1.Entity_ENTITY_REPOSITORIES, originatorProps). - WithProviderImplementsHint("github") + WithProviderImplementsHint("github"). + WithProjectID(projectID). + WithProviderID(providerID) }, setupPropSvcMocks: func() fixtures.MockPropertyServiceBuilder { repoPropsEwp := buildEwp(t, repoEwp, repoPropMap) @@ -625,7 +649,9 @@ func TestRefreshEntityAndDoHandler_HandleRefreshEntityAndEval(t *testing.T) { return message.NewEntityRefreshAndDoMessage(). WithEntity(minderv1.Entity_ENTITY_REPOSITORIES, getByProps). - WithProviderImplementsHint("github") + WithProviderImplementsHint("github"). + WithProjectID(projectID). + WithProviderID(providerID) }, setupPropSvcMocks: func() fixtures.MockPropertyServiceBuilder { repoPropsEwp := buildEwp(t, repoEwp, repoPropMap) @@ -653,7 +679,9 @@ func TestRefreshEntityAndDoHandler_HandleRefreshEntityAndEval(t *testing.T) { return message.NewEntityRefreshAndDoMessage(). WithEntity(minderv1.Entity_ENTITY_REPOSITORIES, getByProps). - WithProviderImplementsHint("github") + WithProviderImplementsHint("github"). + WithProjectID(projectID). + WithProviderID(providerID) }, setupPropSvcMocks: func() fixtures.MockPropertyServiceBuilder { return fixtures.NewMockPropertiesService( diff --git a/internal/entities/handlers/message/message.go b/internal/entities/handlers/message/message.go index 85ae015cb4..f49a4ef057 100644 --- a/internal/entities/handlers/message/message.go +++ b/internal/entities/handlers/message/message.go @@ -41,6 +41,8 @@ type HandleEntityAndDoMessage struct { // use-case is to include the hook ID in the MatchProps to match against // the entity's hook ID to avoid forwading the message to the wrong entity. MatchProps map[string]any `json:"match_props"` + ProjectID uuid.UUID `json:"project_id"` + ProviderID uuid.UUID `json:"provider_id"` } // NewEntityRefreshAndDoMessage creates a new HandleEntityAndDoMessage struct. @@ -74,6 +76,18 @@ func (e *HandleEntityAndDoMessage) WithEntityID(entityID uuid.UUID) *HandleEntit return e } +// WithProjectID sets the project ID for the entity. +func (e *HandleEntityAndDoMessage) WithProjectID(projectID uuid.UUID) *HandleEntityAndDoMessage { + e.ProjectID = projectID + return e +} + +// WithProviderID sets the provider ID for the entity. +func (e *HandleEntityAndDoMessage) WithProviderID(providerID uuid.UUID) *HandleEntityAndDoMessage { + e.ProviderID = providerID + return e +} + // WithProviderImplementsHint sets the provider hint for the entity that will be used when looking up the entity. // to the provider implements hint func (e *HandleEntityAndDoMessage) WithProviderImplementsHint(providerHint string) *HandleEntityAndDoMessage { diff --git a/internal/entities/handlers/strategies/entity/refresh_by_id.go b/internal/entities/handlers/strategies/entity/refresh_by_id.go index a1c0f3dfcb..1988bd4ee4 100644 --- a/internal/entities/handlers/strategies/entity/refresh_by_id.go +++ b/internal/entities/handlers/strategies/entity/refresh_by_id.go @@ -46,7 +46,10 @@ func (r *refreshEntityByIDStrategy) GetEntity( getEnt, err := db.WithTransaction(r.store, func(t db.ExtendQuerier) (*models.EntityWithProperties, error) { ewp, err := r.propSvc.EntityWithPropertiesByID( - ctx, entMsg.Entity.EntityID, + ctx, + entMsg.Entity.EntityID, + entMsg.ProjectID, + entMsg.ProviderID, propertyService.CallBuilder().WithStoreOrTransaction(t)) if err != nil { return nil, fmt.Errorf("error getting entity: %w", err) diff --git a/internal/entities/properties/service/entitycache.go b/internal/entities/properties/service/entitycache.go index e6dd0a98d2..2d477ae971 100644 --- a/internal/entities/properties/service/entitycache.go +++ b/internal/entities/properties/service/entitycache.go @@ -47,7 +47,10 @@ func newPropertyServiceWithPersistentEntityCache( } func (ps *propertyServiceWithPersistentEntityCache) EntityWithPropertiesByID( - ctx context.Context, entityID uuid.UUID, + ctx context.Context, + entityID uuid.UUID, + projectID uuid.UUID, + providerID uuid.UUID, opts *CallOptions, ) (*models.EntityWithProperties, error) { // Check the cache first. @@ -56,7 +59,7 @@ func (ps *propertyServiceWithPersistentEntityCache) EntityWithPropertiesByID( } // If not in the cache, call the underlying service. - ent, err := ps.PropertiesService.EntityWithPropertiesByID(ctx, entityID, opts) + ent, err := ps.PropertiesService.EntityWithPropertiesByID(ctx, entityID, projectID, providerID, opts) if err != nil { return nil, err } diff --git a/internal/entities/properties/service/helpers.go b/internal/entities/properties/service/helpers.go index 8cb7a957e7..d1100dcc7b 100644 --- a/internal/entities/properties/service/helpers.go +++ b/internal/entities/properties/service/helpers.go @@ -24,6 +24,7 @@ import ( func (ps *propertiesService) retrieveAllPropertiesForEntity( ctx context.Context, provider provifv1.Provider, entID uuid.UUID, + projectID uuid.UUID, providerID uuid.UUID, lookupProperties *properties.Properties, entType minderv1.Entity, opts *ReadOptions, l zerolog.Logger, ) (*properties.Properties, error) { @@ -33,7 +34,11 @@ func (ps *propertiesService) retrieveAllPropertiesForEntity( if entID != uuid.Nil { // fetch properties from db var err error - dbProps, err = qtx.GetAllPropertiesForEntity(ctx, entID) + dbProps, err = qtx.GetAllPropertiesForEntity(ctx, db.GetAllPropertiesForEntityParams{ + EntityID: entID, + ProjectID: projectID, + ProviderID: providerID, + }) if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, err } @@ -68,7 +73,7 @@ func (ps *propertiesService) retrieveAllPropertiesForEntity( } // save updated properties to db, thus making sure that the updatedAt are bumped - err = ps.ReplaceAllProperties(ctx, entID, refreshedProps, opts.getPropertiesServiceCallOptions()) + err = ps.ReplaceAllProperties(ctx, entID, refreshedProps, projectID, providerID, opts.getPropertiesServiceCallOptions()) if err != nil { return nil, fmt.Errorf("failed to update properties: %w", err) } diff --git a/internal/entities/properties/service/mock/fixtures/fixtures.go b/internal/entities/properties/service/mock/fixtures/fixtures.go index 2afdad191f..3cafd5fd25 100644 --- a/internal/entities/properties/service/mock/fixtures/fixtures.go +++ b/internal/entities/properties/service/mock/fixtures/fixtures.go @@ -64,7 +64,8 @@ func WithSuccessfulEntityWithPropertiesByID( ewp *models.EntityWithProperties, ) MockPropertyServiceOption { return func(mockPropSvc *mockSvc.MockPropertiesService) { - mockPropSvc.EXPECT().EntityWithPropertiesByID(gomock.Any(), entityID, gomock.Any()). + mockPropSvc.EXPECT(). + EntityWithPropertiesByID(gomock.Any(), entityID, ewp.Entity.ProjectID, ewp.Entity.ProviderID, gomock.Any()). Return(ewp, nil) } } @@ -130,7 +131,7 @@ func WithFailedGetEntityWithPropertiesByID( ) MockPropertyServiceOption { return func(mockPropSvc *mockSvc.MockPropertiesService) { mockPropSvc.EXPECT(). - EntityWithPropertiesByID(gomock.Any(), gomock.Any(), gomock.Any()). + EntityWithPropertiesByID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, err) } } diff --git a/internal/entities/properties/service/mock/service.go b/internal/entities/properties/service/mock/service.go index 7d047db7a1..05eed317b9 100644 --- a/internal/entities/properties/service/mock/service.go +++ b/internal/entities/properties/service/mock/service.go @@ -64,18 +64,18 @@ func (mr *MockPropertiesServiceMockRecorder) EntityWithPropertiesAsProto(ctx, ew } // EntityWithPropertiesByID mocks base method. -func (m *MockPropertiesService) EntityWithPropertiesByID(ctx context.Context, entityID uuid.UUID, opts *service.CallOptions) (*models.EntityWithProperties, error) { +func (m *MockPropertiesService) EntityWithPropertiesByID(ctx context.Context, entityID, projectID, providerID uuid.UUID, opts *service.CallOptions) (*models.EntityWithProperties, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "EntityWithPropertiesByID", ctx, entityID, opts) + ret := m.ctrl.Call(m, "EntityWithPropertiesByID", ctx, entityID, projectID, providerID, opts) ret0, _ := ret[0].(*models.EntityWithProperties) ret1, _ := ret[1].(error) return ret0, ret1 } // EntityWithPropertiesByID indicates an expected call of EntityWithPropertiesByID. -func (mr *MockPropertiesServiceMockRecorder) EntityWithPropertiesByID(ctx, entityID, opts any) *gomock.Call { +func (mr *MockPropertiesServiceMockRecorder) EntityWithPropertiesByID(ctx, entityID, projectID, providerID, opts any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EntityWithPropertiesByID", reflect.TypeOf((*MockPropertiesService)(nil).EntityWithPropertiesByID), ctx, entityID, opts) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EntityWithPropertiesByID", reflect.TypeOf((*MockPropertiesService)(nil).EntityWithPropertiesByID), ctx, entityID, projectID, providerID, opts) } // EntityWithPropertiesByUpstreamHint mocks base method. @@ -94,31 +94,31 @@ func (mr *MockPropertiesServiceMockRecorder) EntityWithPropertiesByUpstreamHint( } // ReplaceAllProperties mocks base method. -func (m *MockPropertiesService) ReplaceAllProperties(ctx context.Context, entityID uuid.UUID, props *properties.Properties, opts *service.CallOptions) error { +func (m *MockPropertiesService) ReplaceAllProperties(ctx context.Context, entityID uuid.UUID, props *properties.Properties, projectID, providerID uuid.UUID, opts *service.CallOptions) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReplaceAllProperties", ctx, entityID, props, opts) + ret := m.ctrl.Call(m, "ReplaceAllProperties", ctx, entityID, props, projectID, providerID, opts) ret0, _ := ret[0].(error) return ret0 } // ReplaceAllProperties indicates an expected call of ReplaceAllProperties. -func (mr *MockPropertiesServiceMockRecorder) ReplaceAllProperties(ctx, entityID, props, opts any) *gomock.Call { +func (mr *MockPropertiesServiceMockRecorder) ReplaceAllProperties(ctx, entityID, props, projectID, providerID, opts any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceAllProperties", reflect.TypeOf((*MockPropertiesService)(nil).ReplaceAllProperties), ctx, entityID, props, opts) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceAllProperties", reflect.TypeOf((*MockPropertiesService)(nil).ReplaceAllProperties), ctx, entityID, props, projectID, providerID, opts) } // ReplaceProperty mocks base method. -func (m *MockPropertiesService) ReplaceProperty(ctx context.Context, entityID uuid.UUID, key string, prop *properties.Property, opts *service.CallOptions) error { +func (m *MockPropertiesService) ReplaceProperty(ctx context.Context, entityID uuid.UUID, key string, prop *properties.Property, projectID, providerID uuid.UUID, opts *service.CallOptions) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReplaceProperty", ctx, entityID, key, prop, opts) + ret := m.ctrl.Call(m, "ReplaceProperty", ctx, entityID, key, prop, projectID, providerID, opts) ret0, _ := ret[0].(error) return ret0 } // ReplaceProperty indicates an expected call of ReplaceProperty. -func (mr *MockPropertiesServiceMockRecorder) ReplaceProperty(ctx, entityID, key, prop, opts any) *gomock.Call { +func (mr *MockPropertiesServiceMockRecorder) ReplaceProperty(ctx, entityID, key, prop, projectID, providerID, opts any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceProperty", reflect.TypeOf((*MockPropertiesService)(nil).ReplaceProperty), ctx, entityID, key, prop, opts) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceProperty", reflect.TypeOf((*MockPropertiesService)(nil).ReplaceProperty), ctx, entityID, key, prop, projectID, providerID, opts) } // RetrieveAllProperties mocks base method. diff --git a/internal/entities/properties/service/service.go b/internal/entities/properties/service/service.go index 72de9042e7..ffe575c4a4 100644 --- a/internal/entities/properties/service/service.go +++ b/internal/entities/properties/service/service.go @@ -50,7 +50,7 @@ type PropertiesService interface { ) (protoreflect.ProtoMessage, error) // EntityWithPropertiesByID Fetches an Entity by ID and Project in order to refresh the properties EntityWithPropertiesByID( - ctx context.Context, entityID uuid.UUID, opts *CallOptions, + ctx context.Context, entityID uuid.UUID, projectID uuid.UUID, providerID uuid.UUID, opts *CallOptions, ) (*models.EntityWithProperties, error) // EntityWithPropertiesByUpstreamHint fetches an entity by upstream properties // and returns the entity with its properties. It is expected that the caller @@ -83,7 +83,12 @@ type PropertiesService interface { ) error // ReplaceAllProperties saves all properties for the given entity ReplaceAllProperties( - ctx context.Context, entityID uuid.UUID, props *properties.Properties, opts *CallOptions, + ctx context.Context, + entityID uuid.UUID, + props *properties.Properties, + projectID uuid.UUID, + providerID uuid.UUID, + opts *CallOptions, ) error // SaveAllProperties saves all properties for the given entity SaveAllProperties( @@ -91,7 +96,13 @@ type PropertiesService interface { ) error // ReplaceProperty saves a single property for the given entity ReplaceProperty( - ctx context.Context, entityID uuid.UUID, key string, prop *properties.Property, opts *CallOptions, + ctx context.Context, + entityID uuid.UUID, + key string, + prop *properties.Property, + projectID uuid.UUID, + providerID uuid.UUID, + opts *CallOptions, ) error } @@ -144,7 +155,7 @@ func (ps *propertiesService) RetrieveAllProperties( return nil, fmt.Errorf("failed to get entity ID: %w", err) } - return ps.retrieveAllPropertiesForEntity(ctx, provider, entID, lookupProperties, entType, opts, l) + return ps.retrieveAllPropertiesForEntity(ctx, provider, entID, projectId, providerID, lookupProperties, entType, opts, l) } func (ps *propertiesService) RetrieveAllPropertiesForEntity( @@ -164,7 +175,16 @@ func (ps *propertiesService) RetrieveAllPropertiesForEntity( return fmt.Errorf("error instantiating provider: %w", err) } - props, err := ps.retrieveAllPropertiesForEntity(ctx, propClient, efp.Entity.ID, efp.Properties, efp.Entity.Type, opts, l) + props, err := ps.retrieveAllPropertiesForEntity( + ctx, + propClient, + efp.Entity.ID, + efp.Entity.ProjectID, + efp.Entity.ProviderID, + efp.Properties, + efp.Entity.Type, + opts, + l) if err != nil { return fmt.Errorf("error fetching properties for entity: %w", err) } @@ -175,6 +195,7 @@ func (ps *propertiesService) RetrieveAllPropertiesForEntity( func (ps *propertiesService) ReplaceAllProperties( ctx context.Context, entityID uuid.UUID, props *properties.Properties, + projectID uuid.UUID, providerID uuid.UUID, opts *CallOptions, ) error { qtx := ps.getStoreOrTransaction(opts) @@ -182,11 +203,11 @@ func (ps *propertiesService) ReplaceAllProperties( if store, ok := qtx.(db.Store); ok { return store.WithTransactionErr(func(qtx db.ExtendQuerier) error { - return ps.replaceAllPropertiesWithQuerier(ctx, entityID, props, qtx) + return ps.replaceAllPropertiesWithQuerier(ctx, entityID, props, projectID, providerID, qtx) }) } - return ps.replaceAllPropertiesWithQuerier(ctx, entityID, props, qtx) + return ps.replaceAllPropertiesWithQuerier(ctx, entityID, props, projectID, providerID, qtx) } func (ps *propertiesService) SaveAllProperties( @@ -205,9 +226,15 @@ func (ps *propertiesService) SaveAllProperties( } func (ps *propertiesService) replaceAllPropertiesWithQuerier( - ctx context.Context, entityID uuid.UUID, props *properties.Properties, qtx db.ExtendQuerier, + ctx context.Context, entityID uuid.UUID, props *properties.Properties, + projectID uuid.UUID, providerID uuid.UUID, qtx db.ExtendQuerier, ) error { - err := qtx.DeleteAllPropertiesForEntity(ctx, entityID) + err := qtx.DeleteAllPropertiesForEntity(ctx, db.DeleteAllPropertiesForEntityParams{ + EntityID: entityID, + ProjectID: projectID, + ProviderID: providerID, + }) + if err != nil { return fmt.Errorf("failed to delete properties: %w", err) } @@ -234,13 +261,16 @@ func (*propertiesService) saveAllPropertiesWithQuerier( func (ps *propertiesService) ReplaceProperty( ctx context.Context, entityID uuid.UUID, key string, prop *properties.Property, + projectID uuid.UUID, providerID uuid.UUID, opts *CallOptions, ) error { qtx := ps.getStoreOrTransaction(opts) if prop == nil { return qtx.DeleteProperty(ctx, db.DeletePropertyParams{ - EntityID: entityID, - Key: key, + EntityID: entityID, + Key: key, + ProjectID: projectID, + ProviderID: providerID, }) } @@ -259,7 +289,12 @@ func (ps *propertiesService) getEntityWithProperties( ) (*models.EntityWithProperties, error) { q := ps.getStoreOrTransaction(opts) - dbProps, err := q.GetAllPropertiesForEntity(ctx, ent.ID) + dbProps, err := q.GetAllPropertiesForEntity(ctx, db.GetAllPropertiesForEntityParams{ + EntityID: ent.ID, + ProjectID: ent.ProjectID, + ProviderID: ent.ProviderID, + }) + if errors.Is(err, sql.ErrNoRows) { return nil, fmt.Errorf("failed to get properties for entity: %w", ErrEntityNotFound) } else if err != nil { @@ -286,12 +321,17 @@ func (ps *propertiesService) getEntityWithProperties( } func (ps *propertiesService) EntityWithPropertiesByID( - ctx context.Context, entityID uuid.UUID, + ctx context.Context, entityID uuid.UUID, projectID uuid.UUID, providerID uuid.UUID, opts *CallOptions, ) (*models.EntityWithProperties, error) { q := ps.getStoreOrTransaction(opts) - ent, err := q.GetEntityByID(ctx, entityID) + ent, err := q.GetEntityByID(ctx, db.GetEntityByIDParams{ + ID: entityID, + ProjectID: projectID, + ProviderID: providerID, + }) + if errors.Is(err, sql.ErrNoRows) { return nil, ErrEntityNotFound } else if err != nil { diff --git a/internal/entities/properties/service/service_test.go b/internal/entities/properties/service/service_test.go index c92744be6f..49eabfbdb7 100644 --- a/internal/entities/properties/service/service_test.go +++ b/internal/entities/properties/service/service_test.go @@ -276,6 +276,10 @@ func TestPropertiesService_SaveProperty(t *testing.T) { require.NoError(t, err) propSvc := NewPropertiesService(tctx.testQueries) + if tt.dbSetup != nil { + tt.dbSetup(t, ent.ID, tctx.testQueries) + } + var prop *properties.Property if tt.val != nil { prop, err = properties.NewProperty(tt.val) @@ -283,14 +287,16 @@ func TestPropertiesService_SaveProperty(t *testing.T) { } err = tctx.testQueries.WithTransactionErr(func(qtx db.ExtendQuerier) error { - return propSvc.ReplaceProperty(ctx, ent.ID, tt.key, prop, + return propSvc.ReplaceProperty(ctx, ent.ID, tt.key, prop, tctx.dbProj.ID, tctx.ghAppProvider.ID, CallBuilder().WithStoreOrTransaction(qtx)) }) require.NoError(t, err) dbProp, err := tctx.testQueries.GetProperty(ctx, db.GetPropertyParams{ - EntityID: ent.ID, - Key: tt.key, + EntityID: ent.ID, + Key: tt.key, + ProjectID: tctx.dbProj.ID, + ProviderID: tctx.ghAppProvider.ID, }) if tt.val == nil { require.ErrorIs(t, err, sql.ErrNoRows) @@ -429,15 +435,23 @@ func TestPropertiesService_SaveAllProperties(t *testing.T) { require.NoError(t, err) propSvc := NewPropertiesService(tctx.testQueries) + if tt.dbSetup != nil { + tt.dbSetup(t, ent.ID, tctx.testQueries) + } + props := properties.NewProperties(tt.props) err = tctx.testQueries.WithTransactionErr(func(qtx db.ExtendQuerier) error { - return propSvc.ReplaceAllProperties(ctx, ent.ID, props, + return propSvc.ReplaceAllProperties(ctx, ent.ID, props, tctx.dbProj.ID, tctx.ghAppProvider.ID, CallBuilder().WithStoreOrTransaction(qtx)) }) require.NoError(t, err) - dbProps, err := tctx.testQueries.GetAllPropertiesForEntity(ctx, ent.ID) + dbProps, err := tctx.testQueries.GetAllPropertiesForEntity(ctx, db.GetAllPropertiesForEntityParams{ + EntityID: ent.ID, + ProjectID: tctx.dbProj.ID, + ProviderID: tctx.ghAppProvider.ID, + }) require.NoError(t, err) updatedProps, err := models.DbPropsToModel(dbProps) @@ -904,50 +918,51 @@ func TestPropertiesService_EntityWithProperties(t *testing.T) { scenarios := []struct { name string entityID uuid.UUID + projectID uuid.UUID + providerID uuid.UUID entName string - dbEntBuilder func(id uuid.UUID, entName string) db.EntityInstance + dbEntBuilder func(id uuid.UUID, projID uuid.UUID, provID uuid.UUID, entName string) db.EntityInstance dbPropsBuilder func(id uuid.UUID) []db.Property checkProps func(t *testing.T, props *properties.Properties) }{ { - name: "Entity with properties", - entityID: uuid.New(), - entName: "myorg/the-props-are-different", - dbEntBuilder: func(id uuid.UUID, entName string) db.EntityInstance { + name: "Entity with properties", + entityID: uuid.New(), + projectID: uuid.New(), + providerID: uuid.New(), + entName: "myorg/the-props-are-different", + dbEntBuilder: func(id uuid.UUID, projID uuid.UUID, provID uuid.UUID, entName string) db.EntityInstance { return db.EntityInstance{ - ID: id, - Name: entName, + ID: id, + ProjectID: projID, + ProviderID: provID, + Name: entName, } }, dbPropsBuilder: func(id uuid.UUID) []db.Property { return []db.Property{ - { - EntityID: id, - Key: "name", - Value: []byte(`{"value": "myorg/bad-go", "version": "v1"}`), - }, - { - EntityID: id, - Key: "is_private", - Value: []byte(`{"value": false, "version": "v1"}`), - }, + {EntityID: id, Key: "name", Value: []byte(`{"value": "myorg/bad-go", "version": "v1"}`)}, + {EntityID: id, Key: "is_private", Value: []byte(`{"value": false, "version": "v1"}`)}, } }, checkProps: func(t *testing.T, props *properties.Properties) { t.Helper() - require.Equal(t, props.GetProperty("name").GetString(), "myorg/bad-go") require.Equal(t, props.GetProperty("is_private").GetBool(), false) }, }, { - name: "Entity without properties", - entityID: uuid.New(), - entName: "myorg/noprops", - dbEntBuilder: func(id uuid.UUID, entName string) db.EntityInstance { + name: "Entity without properties", + entityID: uuid.New(), + projectID: uuid.New(), + providerID: uuid.New(), + entName: "myorg/noprops", + dbEntBuilder: func(id uuid.UUID, projID uuid.UUID, provID uuid.UUID, entName string) db.EntityInstance { return db.EntityInstance{ - ID: id, - Name: entName, + ID: id, + ProjectID: projID, + ProviderID: provID, + Name: entName, } }, dbPropsBuilder: func(_ uuid.UUID) []db.Property { @@ -955,7 +970,6 @@ func TestPropertiesService_EntityWithProperties(t *testing.T) { }, checkProps: func(t *testing.T, props *properties.Properties) { t.Helper() - require.Equal(t, props.GetProperty("name").GetString(), "myorg/noprops") }, }, @@ -971,14 +985,14 @@ func TestPropertiesService_EntityWithProperties(t *testing.T) { mockDB := mockdb.NewMockStore(ctrl) mockDB.EXPECT(). - GetEntityByID(ctx, tt.entityID). - Return(tt.dbEntBuilder(tt.entityID, tt.entName), nil) + GetEntityByID(ctx, db.GetEntityByIDParams{ID: tt.entityID, ProjectID: tt.projectID, ProviderID: tt.providerID}). + Return(tt.dbEntBuilder(tt.entityID, tt.projectID, tt.providerID, tt.entName), nil) mockDB.EXPECT(). - GetAllPropertiesForEntity(ctx, tt.entityID). + GetAllPropertiesForEntity(ctx, db.GetAllPropertiesForEntityParams{EntityID: tt.entityID, ProjectID: tt.projectID, ProviderID: tt.providerID}). Return(tt.dbPropsBuilder(tt.entityID), nil) ps := NewPropertiesService(mockDB) - result, err := ps.EntityWithPropertiesByID(ctx, tt.entityID, nil) + result, err := ps.EntityWithPropertiesByID(ctx, tt.entityID, tt.projectID, tt.providerID, nil) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, result.Entity.ID, tt.entityID) @@ -987,7 +1001,6 @@ func TestPropertiesService_EntityWithProperties(t *testing.T) { }) } } - func TestPropertiesService_EntityWithProperties_WithCache(t *testing.T) { t.Parallel() @@ -1003,32 +1016,27 @@ func TestPropertiesService_EntityWithProperties_WithCache(t *testing.T) { mockDB := mockdb.NewMockStore(ctrl) entityID := uuid.New() + projectID := uuid.New() + providerID := uuid.New() entityName := "myorg/bad-go" entityRet := db.EntityInstance{ - ID: entityID, - Name: entityName, + ID: entityID, + ProjectID: projectID, + ProviderID: providerID, + Name: entityName, } propertyRet := []db.Property{ - { - EntityID: entityID, - Key: "name", - Value: []byte(`{"value": "myorg/bad-go", "version": "v1"}`), - }, - { - EntityID: entityID, - Key: "is_private", - Value: []byte(`{"value": false, "version": "v1"}`), - }, + {EntityID: entityID, Key: "name", Value: []byte(`{"value": "myorg/bad-go", "version": "v1"}`)}, + {EntityID: entityID, Key: "is_private", Value: []byte(`{"value": false, "version": "v1"}`)}, } - // we verify that the entity is only fetched once even though we call the service twice mockDB.EXPECT(). - GetEntityByID(ctx, entityID). + GetEntityByID(ctx, db.GetEntityByIDParams{ID: entityID, ProjectID: projectID, ProviderID: providerID}). Return(entityRet, nil). Times(1) mockDB.EXPECT(). - GetAllPropertiesForEntity(ctx, entityID). + GetAllPropertiesForEntity(ctx, db.GetAllPropertiesForEntityParams{EntityID: entityID, ProjectID: projectID, ProviderID: providerID}). Return(propertyRet, nil). Times(1) @@ -1037,22 +1045,16 @@ func TestPropertiesService_EntityWithProperties_WithCache(t *testing.T) { require.NoError(t, err) t.Log("First call, no cache") - result, err := cps.EntityWithPropertiesByID(ctx, entityID, nil) + result, err := cps.EntityWithPropertiesByID(ctx, entityID, projectID, providerID, nil) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, result.Entity.ID, entityID) - require.Equal(t, result.Entity.Name, entityName) - require.Equal(t, result.Properties.GetProperty("name").GetString(), "myorg/bad-go") - require.Equal(t, result.Properties.GetProperty("is_private").GetBool(), false) t.Log("Second call, cache hit") - result, err = cps.EntityWithPropertiesByID(ctx, entityID, nil) + result, err = cps.EntityWithPropertiesByID(ctx, entityID, projectID, providerID, nil) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, result.Entity.ID, entityID) - require.Equal(t, result.Entity.Name, entityName) - require.Equal(t, result.Properties.GetProperty("name").GetString(), "myorg/bad-go") - require.Equal(t, result.Properties.GetProperty("is_private").GetBool(), false) }) t.Run("Errors are propagated", func(t *testing.T) { @@ -1061,9 +1063,11 @@ func TestPropertiesService_EntityWithProperties_WithCache(t *testing.T) { mockDB := mockdb.NewMockStore(ctrl) entityID := uuid.New() + projectID := uuid.New() + providerID := uuid.New() mockDB.EXPECT(). - GetEntityByID(ctx, entityID). + GetEntityByID(ctx, db.GetEntityByIDParams{ID: entityID, ProjectID: projectID, ProviderID: providerID}). Return(db.EntityInstance{}, ErrEntityNotFound). Times(1) @@ -1071,7 +1075,7 @@ func TestPropertiesService_EntityWithProperties_WithCache(t *testing.T) { cps, err := WithEntityCache(ps, 100) require.NoError(t, err) - result, err := cps.EntityWithPropertiesByID(ctx, entityID, nil) + result, err := cps.EntityWithPropertiesByID(ctx, entityID, projectID, providerID, nil) require.ErrorIs(t, err, ErrEntityNotFound) require.Nil(t, result) }) @@ -1089,6 +1093,8 @@ func TestPropertiesService_MultiPropertyWrites_EnsureTransaction(t *testing.T) { ctx := context.Background() entityID := uuid.New() + projectID := uuid.New() + providerID := uuid.New() props := properties.NewProperties(map[string]any{ "name": "repo-1", "is_private": true, @@ -1096,21 +1102,21 @@ func TestPropertiesService_MultiPropertyWrites_EnsureTransaction(t *testing.T) { scenarios := []struct { name string - invoke func(PropertiesService, context.Context, uuid.UUID, *properties.Properties, *CallOptions) error + invoke func(PropertiesService, context.Context, uuid.UUID, uuid.UUID, uuid.UUID, *properties.Properties, *CallOptions) error expectDeleteBeforePut bool }{ { name: "replace all properties", - invoke: func(ps PropertiesService, ctx context.Context, entityID uuid.UUID, + invoke: func(ps PropertiesService, ctx context.Context, entityID uuid.UUID, projectID uuid.UUID, providerID uuid.UUID, props *properties.Properties, opts *CallOptions, ) error { - return ps.ReplaceAllProperties(ctx, entityID, props, opts) + return ps.ReplaceAllProperties(ctx, entityID, props, projectID, providerID, opts) }, expectDeleteBeforePut: true, }, { name: "save all properties", - invoke: func(ps PropertiesService, ctx context.Context, entityID uuid.UUID, + invoke: func(ps PropertiesService, ctx context.Context, entityID uuid.UUID, _ uuid.UUID, _ uuid.UUID, props *properties.Properties, opts *CallOptions, ) error { return ps.SaveAllProperties(ctx, entityID, props, opts) @@ -1137,7 +1143,11 @@ func TestPropertiesService_MultiPropertyWrites_EnsureTransaction(t *testing.T) { }) if tt.expectDeleteBeforePut { mockTxQuerier.EXPECT(). - DeleteAllPropertiesForEntity(ctx, entityID). + DeleteAllPropertiesForEntity(ctx, db.DeleteAllPropertiesForEntityParams{ + EntityID: entityID, + ProjectID: projectID, + ProviderID: providerID, + }). Return(nil) } mockTxQuerier.EXPECT(). @@ -1146,7 +1156,7 @@ func TestPropertiesService_MultiPropertyWrites_EnsureTransaction(t *testing.T) { Times(2) ps := NewPropertiesService(mockStore) - err := tt.invoke(ps, ctx, entityID, props, nil) + err := tt.invoke(ps, ctx, entityID, projectID, providerID, props, nil) require.NoError(t, err) }) @@ -1170,11 +1180,15 @@ func TestPropertiesService_MultiPropertyWrites_EnsureTransaction(t *testing.T) { txQuerier := tctx.testQueries.GetQuerierWithTransaction(tx) ps := NewPropertiesService(tctx.testQueries) - err = tt.invoke(ps, ctx, ent.ID, props, + err = tt.invoke(ps, ctx, ent.ID, tctx.dbProj.ID, tctx.ghAppProvider.ID, props, CallBuilder().WithStoreOrTransaction(txQuerier)) require.NoError(t, err) - stored, err := txQuerier.GetAllPropertiesForEntity(ctx, ent.ID) + stored, err := txQuerier.GetAllPropertiesForEntity(ctx, db.GetAllPropertiesForEntityParams{ + EntityID: ent.ID, + ProjectID: tctx.dbProj.ID, + ProviderID: tctx.ghAppProvider.ID, + }) require.NoError(t, err) require.Len(t, stored, 2) }) @@ -1193,7 +1207,7 @@ func TestPropertiesService_MultiPropertyWrites_EnsureTransaction(t *testing.T) { Return(expectedErr) ps := NewPropertiesService(mockStore) - err := tt.invoke(ps, ctx, entityID, props, nil) + err := tt.invoke(ps, ctx, entityID, projectID, providerID, props, nil) require.ErrorIs(t, err, expectedErr) }) } @@ -1222,10 +1236,14 @@ func TestPropertiesService_MultiPropertyWrites_EnsureTransaction(t *testing.T) { err: errors.New("forced upsert failure"), }) - err = ps.ReplaceAllProperties(ctx, ent.ID, props, nil) + err = ps.ReplaceAllProperties(ctx, ent.ID, props, tctx.dbProj.ID, tctx.ghAppProvider.ID, nil) require.ErrorContains(t, err, "forced upsert failure") - stored, err := tctx.testQueries.GetAllPropertiesForEntity(ctx, ent.ID) + stored, err := tctx.testQueries.GetAllPropertiesForEntity(ctx, db.GetAllPropertiesForEntityParams{ + EntityID: ent.ID, + ProjectID: tctx.dbProj.ID, + ProviderID: tctx.ghAppProvider.ID, + }) require.NoError(t, err) restoredProps, err := models.DbPropsToModel(stored) diff --git a/internal/entities/service/entity_creator.go b/internal/entities/service/entity_creator.go index ae912a648b..ec88e16dd9 100644 --- a/internal/entities/service/entity_creator.go +++ b/internal/entities/service/entity_creator.go @@ -168,7 +168,7 @@ func (e *entityCreator) CreateEntity( // Replace properties - use Replace to ensure a clean slate // (removes any stale properties from previous failed attempts) - if err := e.propSvc.ReplaceAllProperties(ctx, ent.ID, registeredProps, + if err := e.propSvc.ReplaceAllProperties(ctx, ent.ID, registeredProps, projectID, provider.ID, propService.CallBuilder().WithStoreOrTransaction(t)); err != nil { return nil, fmt.Errorf("error saving properties: %w", err) } diff --git a/internal/entities/service/entity_creator_simple_test.go b/internal/entities/service/entity_creator_simple_test.go index dd29726f57..22dc9357b8 100644 --- a/internal/entities/service/entity_creator_simple_test.go +++ b/internal/entities/service/entity_creator_simple_test.go @@ -269,11 +269,15 @@ func TestEntityCreator_Integration_HappyPath(t *testing.T) { res, err := creator.CreateEntity(ctx, &dbProvider, project.ID, pb.Entity_ENTITY_REPOSITORIES, nil, nil) require.NoError(t, err) - dbEnt, err := realStore.GetEntityByID(ctx, res.Entity.ID) + dbEnt, err := realStore.GetEntityByID(ctx, db.GetEntityByIDParams{ + ID: res.Entity.ID, + ProjectID: project.ID, + ProviderID: dbProvider.ID, + }) require.NoError(t, err) assert.Equal(t, db.EntitiesRepository, dbEnt.EntityType) - savedEntity, _ := realPropSvc.EntityWithPropertiesByID(ctx, res.Entity.ID, nil) + savedEntity, _ := realPropSvc.EntityWithPropertiesByID(ctx, res.Entity.ID, project.ID, dbProvider.ID, nil) assert.Contains(t, fmt.Sprintf("%+v", savedEntity.Properties), "my-test-repo") } @@ -319,7 +323,11 @@ func TestEntityCreator_Integration_WithOriginator(t *testing.T) { }) require.NoError(t, err) - dbEnt, _ := realStore.GetEntityByID(ctx, res.Entity.ID) + dbEnt, _ := realStore.GetEntityByID(ctx, db.GetEntityByIDParams{ + ID: res.Entity.ID, + ProjectID: project.ID, + ProviderID: dbProvider.ID, + }) assert.Equal(t, parentID, dbEnt.OriginatedFrom.UUID) } @@ -353,7 +361,7 @@ func TestEntityCreator_Integration_RollbackCleanup(t *testing.T) { mockProv.EXPECT().RegisterEntity(gomock.Any(), gomock.Any(), gomock.Any()).Return(testProps, nil) // db save fails - mockPropSvc.EXPECT().ReplaceAllProperties(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + mockPropSvc.EXPECT().ReplaceAllProperties(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(errors.New("database explosion")) // ensure we match the props being deregistered diff --git a/internal/entities/service/mock/service.go b/internal/entities/service/mock/service.go index 484285146d..af5462ac55 100644 --- a/internal/entities/service/mock/service.go +++ b/internal/entities/service/mock/service.go @@ -43,32 +43,32 @@ func (m *MockEntityService) EXPECT() *MockEntityServiceMockRecorder { } // DeleteEntityByID mocks base method. -func (m *MockEntityService) DeleteEntityByID(ctx context.Context, entityID, projectID uuid.UUID) error { +func (m *MockEntityService) DeleteEntityByID(ctx context.Context, entityID, projectID, providerID uuid.UUID) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteEntityByID", ctx, entityID, projectID) + ret := m.ctrl.Call(m, "DeleteEntityByID", ctx, entityID, projectID, providerID) ret0, _ := ret[0].(error) return ret0 } // DeleteEntityByID indicates an expected call of DeleteEntityByID. -func (mr *MockEntityServiceMockRecorder) DeleteEntityByID(ctx, entityID, projectID any) *gomock.Call { +func (mr *MockEntityServiceMockRecorder) DeleteEntityByID(ctx, entityID, projectID, providerID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteEntityByID", reflect.TypeOf((*MockEntityService)(nil).DeleteEntityByID), ctx, entityID, projectID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteEntityByID", reflect.TypeOf((*MockEntityService)(nil).DeleteEntityByID), ctx, entityID, projectID, providerID) } // GetEntityByID mocks base method. -func (m *MockEntityService) GetEntityByID(ctx context.Context, entityID, projectID uuid.UUID) (*v1.EntityInstance, error) { +func (m *MockEntityService) GetEntityByID(ctx context.Context, entityID, projectID, providerID uuid.UUID) (*v1.EntityInstance, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetEntityByID", ctx, entityID, projectID) + ret := m.ctrl.Call(m, "GetEntityByID", ctx, entityID, projectID, providerID) ret0, _ := ret[0].(*v1.EntityInstance) ret1, _ := ret[1].(error) return ret0, ret1 } // GetEntityByID indicates an expected call of GetEntityByID. -func (mr *MockEntityServiceMockRecorder) GetEntityByID(ctx, entityID, projectID any) *gomock.Call { +func (mr *MockEntityServiceMockRecorder) GetEntityByID(ctx, entityID, projectID, providerID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEntityByID", reflect.TypeOf((*MockEntityService)(nil).GetEntityByID), ctx, entityID, projectID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEntityByID", reflect.TypeOf((*MockEntityService)(nil).GetEntityByID), ctx, entityID, projectID, providerID) } // GetEntityByName mocks base method. diff --git a/internal/entities/service/service.go b/internal/entities/service/service.go index 7b3ff171fb..30f8f57547 100644 --- a/internal/entities/service/service.go +++ b/internal/entities/service/service.go @@ -46,6 +46,7 @@ type EntityService interface { ctx context.Context, entityID uuid.UUID, projectID uuid.UUID, + providerID uuid.UUID, ) (*pb.EntityInstance, error) // GetEntityByName retrieves an entity by its name @@ -62,6 +63,7 @@ type EntityService interface { ctx context.Context, entityID uuid.UUID, projectID uuid.UUID, + providerID uuid.UUID, ) error } @@ -158,7 +160,7 @@ func (s *entityService) ListEntities( break } - ewp, err := s.propSvc.EntityWithPropertiesByID(ctx, entity.ID, + ewp, err := s.propSvc.EntityWithPropertiesByID(ctx, entity.ID, projectID, providerID, propService.CallBuilder().WithStoreOrTransaction(qtx)) if err != nil { return nil, "", fmt.Errorf("error fetching properties for entity: %w", err) @@ -188,9 +190,15 @@ func (s *entityService) GetEntityByID( ctx context.Context, entityID uuid.UUID, projectID uuid.UUID, + providerID uuid.UUID, ) (*pb.EntityInstance, error) { // Get entity from database - entity, err := s.store.GetEntityByID(ctx, entityID) + _, err := s.store.GetEntityByID(ctx, db.GetEntityByIDParams{ + ID: entityID, + ProjectID: projectID, + ProviderID: providerID, + }) + if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, status.Errorf(codes.NotFound, "entity not found") @@ -198,13 +206,8 @@ func (s *entityService) GetEntityByID( return nil, fmt.Errorf("error fetching entity: %w", err) } - // Verify entity belongs to the project - if entity.ProjectID != projectID { - return nil, status.Errorf(codes.NotFound, "entity not found in project") - } - // Get properties - ewp, err := s.propSvc.EntityWithPropertiesByID(ctx, entityID, nil) + ewp, err := s.propSvc.EntityWithPropertiesByID(ctx, entityID, projectID, providerID, nil) if err != nil { return nil, fmt.Errorf("error fetching properties for entity: %w", err) } @@ -246,7 +249,7 @@ func (s *entityService) GetEntityByName( } // Get properties - ewp, err := s.propSvc.EntityWithPropertiesByID(ctx, entity.ID, nil) + ewp, err := s.propSvc.EntityWithPropertiesByID(ctx, entity.ID, projectID, providerID, nil) if err != nil { return nil, fmt.Errorf("error fetching properties for entity: %w", err) } @@ -264,9 +267,14 @@ func (s *entityService) DeleteEntityByID( ctx context.Context, entityID uuid.UUID, projectID uuid.UUID, + providerID uuid.UUID, ) error { // Get entity to verify it exists and belongs to the project - entity, err := s.store.GetEntityByID(ctx, entityID) + _, err := s.store.GetEntityByID(ctx, db.GetEntityByIDParams{ + ID: entityID, + ProjectID: projectID, + ProviderID: providerID, + }) if err != nil { if errors.Is(err, sql.ErrNoRows) { return status.Errorf(codes.NotFound, "entity not found") @@ -274,11 +282,6 @@ func (s *entityService) DeleteEntityByID( return fmt.Errorf("error fetching entity: %w", err) } - // Verify entity belongs to the project - if entity.ProjectID != projectID { - return status.Errorf(codes.NotFound, "entity not found in project") - } - // Delete entity and its properties in a transaction tx, err := s.store.BeginTransaction() if err != nil { @@ -293,14 +296,19 @@ func (s *entityService) DeleteEntityByID( qtx := s.store.GetQuerierWithTransaction(tx) // Delete properties first - if err := qtx.DeleteAllPropertiesForEntity(ctx, entityID); err != nil { + if err := qtx.DeleteAllPropertiesForEntity(ctx, db.DeleteAllPropertiesForEntityParams{ + EntityID: entityID, + ProjectID: projectID, + ProviderID: providerID, + }); err != nil { return fmt.Errorf("error deleting entity properties: %w", err) } // Delete entity if err := qtx.DeleteEntity(ctx, db.DeleteEntityParams{ - ID: entityID, - ProjectID: projectID, + ID: entityID, + ProjectID: projectID, + ProviderID: providerID, }); err != nil { return fmt.Errorf("error deleting entity: %w", err) } diff --git a/internal/history/service.go b/internal/history/service.go index 71b857981a..739ebb69a9 100644 --- a/internal/history/service.go +++ b/internal/history/service.go @@ -234,7 +234,11 @@ func (ehs *evaluationHistoryService) ListEvaluationHistory( data := make([]*OneEvalHistoryAndEntity, 0, len(rows)) for _, row := range rows { - ewp, err := psc.EntityWithPropertiesByID(ctx, row.EntityID, + ewp, err := psc.EntityWithPropertiesByID( + ctx, + row.EntityID, + row.ProjectID, + row.ProviderID, propertiessvc.CallBuilder().WithStoreOrTransaction(qtx)) if err != nil { return nil, fmt.Errorf("error fetching entity for properties: %w", err) diff --git a/internal/history/service_test.go b/internal/history/service_test.go index c70fbdc7ef..73df15e001 100644 --- a/internal/history/service_test.go +++ b/internal/history/service_test.go @@ -825,12 +825,12 @@ func TestListEvaluationHistory(t *testing.T) { pm := pmMock.NewMockProviderManager(ctrl) propsSvc := propsSvcMock.NewMockPropertiesService(ctrl) for _, efp := range tt.efp { - propsSvc.EXPECT().EntityWithPropertiesByID(ctx, gomock.Any(), gomock.Any()). + propsSvc.EXPECT().EntityWithPropertiesByID(ctx, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(efp, tt.entityForPropertiesError) } if tt.entityForPropertiesError != nil && len(tt.efp) == 0 { - propsSvc.EXPECT().EntityWithPropertiesByID(ctx, gomock.Any(), gomock.Any()). + propsSvc.EXPECT().EntityWithPropertiesByID(ctx, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, tt.entityForPropertiesError).AnyTimes() } propsSvc.EXPECT().RetrieveAllPropertiesForEntity(ctx, gomock.Any(), gomock.Any(), gomock.Any()). diff --git a/internal/providers/github/manager/manager.go b/internal/providers/github/manager/manager.go index 142f672bfb..20434aae1a 100644 --- a/internal/providers/github/manager/manager.go +++ b/internal/providers/github/manager/manager.go @@ -178,13 +178,16 @@ func (g *githubProviderManager) Delete(ctx context.Context, config *db.Provider) return err } - entities, err := g.store.GetEntitiesByProvider(ctx, config.ID) + entities, err := g.store.GetEntitiesByProvider(ctx, db.GetEntitiesByProviderParams{ + ProviderID: config.ID, + Projects: []uuid.UUID{config.ProjectID}, + }) if err != nil { return fmt.Errorf("unable to retrieve list of entities to deregister: %w", err) } for _, ent := range entities { - ewp, err := g.propsSvc.EntityWithPropertiesByID(ctx, ent.ID, nil) + ewp, err := g.propsSvc.EntityWithPropertiesByID(ctx, ent.ID, config.ProjectID, config.ID, nil) if err != nil { zerolog.Ctx(ctx).Error().Err(err). Str("provider_id", config.ID.String()). diff --git a/internal/repositories/service.go b/internal/repositories/service.go index 1dea8b818a..a5e6b5119a 100644 --- a/internal/repositories/service.go +++ b/internal/repositories/service.go @@ -185,20 +185,37 @@ func (r *repositoryService) ListRepositories( return nil, fmt.Errorf("error fetching repositories: %w", err) } + if len(repoEnts) == 0 { + return nil, nil + } + + entityIDs := make([]uuid.UUID, len(repoEnts)) + for i, ent := range repoEnts { + entityIDs[i] = ent.ID + } + + dbProps, err := qtx.GetPropertiesForEntities(ctx, db.GetPropertiesForEntitiesParams{ + EntityIds: entityIDs, + Projects: []uuid.UUID{projectID}, + ProviderID: providerID, + }) + if err != nil { + return nil, fmt.Errorf("error fetching bulk properties: %w", err) + } + + propsMap := make(map[uuid.UUID][]db.Property) + for _, p := range dbProps { + propsMap[p.EntityID] = append(propsMap[p.EntityID], p) + } + ents = make([]*models.EntityWithProperties, 0, len(repoEnts)) for _, ent := range repoEnts { - ewp, err := r.propSvc.EntityWithPropertiesByID(ctx, ent.ID, - service.CallBuilder().WithStoreOrTransaction(qtx)) + modelProps, err := models.DbPropsToModel(propsMap[ent.ID]) if err != nil { - return nil, fmt.Errorf("error fetching properties for repository: %w", err) - } - - if err := r.propSvc.RetrieveAllPropertiesForEntity(ctx, ewp, r.providerManager, - service.ReadBuilder().WithStoreOrTransaction(qtx).TolerateStaleData()); err != nil { - return nil, fmt.Errorf("error fetching properties for repository: %w", err) + return nil, fmt.Errorf("failed to convert properties for %s: %w", ent.ID, err) } - ents = append(ents, ewp) + ents = append(ents, models.NewEntityWithProperties(ent, modelProps)) } // We care about commiting the transaction since the `RetrieveAllPropertiesForEntity` @@ -215,30 +232,36 @@ func (r *repositoryService) GetRepositoryById( repositoryID uuid.UUID, projectID uuid.UUID, ) (*pb.Repository, error) { - ewp, err := r.propSvc.EntityWithPropertiesByID(ctx, repositoryID, nil) + + ent, err := r.store.GetEntityByID(ctx, db.GetEntityByIDParams{ + ID: repositoryID, + ProjectID: projectID, + }) if err != nil { - return nil, fmt.Errorf("error fetching repository: %w", err) + return nil, fmt.Errorf("error fetching repository instance: %w", err) } - // Verify the entity belongs to the correct project - if ewp.Entity.ProjectID != projectID { - return nil, sql.ErrNoRows + dbProps, err := r.store.GetAllPropertiesForEntity(ctx, db.GetAllPropertiesForEntityParams{ + EntityID: repositoryID, + ProjectID: projectID, + ProviderID: ent.ProviderID, + }) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("error fetching properties: %w", err) } - // Verify it's a repository entity - if ewp.Entity.Type != pb.Entity_ENTITY_REPOSITORIES { - return nil, fmt.Errorf("entity is not a repository") - } + modelProps, _ := models.DbPropsToModel(dbProps) + ewp := models.NewEntityWithProperties(ent, modelProps) // Retrieve all properties from provider if err := r.propSvc.RetrieveAllPropertiesForEntity(ctx, ewp, r.providerManager, nil); err != nil { - return nil, fmt.Errorf("error fetching properties for repository: %w", err) + return nil, fmt.Errorf("error fetching properties: %w", err) } // Convert to protobuf somePB, err := r.propSvc.EntityWithPropertiesAsProto(ctx, ewp, r.providerManager) if err != nil { - return nil, fmt.Errorf("error converting entity to protobuf: %w", err) + return nil, fmt.Errorf("error converting entity: %w", err) } pbRepo, ok := somePB.(*pb.Repository) @@ -291,12 +314,20 @@ func (r *repositoryService) GetRepositoryByName( return nil, sql.ErrNoRows } - // Use the first matching entity - ewp, err := r.propSvc.EntityWithPropertiesByID(ctx, entities[0].ID, nil) - if err != nil { - return nil, fmt.Errorf("error fetching repository: %w", err) + ent := entities[0] + + dbProps, err := r.store.GetAllPropertiesForEntity(ctx, db.GetAllPropertiesForEntityParams{ + EntityID: ent.ID, + ProjectID: ent.ProjectID, + ProviderID: ent.ProviderID, + }) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("error fetching properties: %w", err) } + modelProps, _ := models.DbPropsToModel(dbProps) + ewp := models.NewEntityWithProperties(ent, modelProps) + // Retrieve all properties from provider if err := r.propSvc.RetrieveAllPropertiesForEntity(ctx, ewp, r.providerManager, nil); err != nil { return nil, fmt.Errorf("error fetching properties for repository: %w", err) @@ -320,11 +351,26 @@ func (r *repositoryService) DeleteByID(ctx context.Context, repositoryID uuid.UU logger.BusinessRecord(ctx).Project = projectID logger.BusinessRecord(ctx).Repository = repositoryID - ent, err := r.propSvc.EntityWithPropertiesByID(ctx, repositoryID, nil) + entInstance, err := r.store.GetEntityByID(ctx, db.GetEntityByIDParams{ + ID: repositoryID, + ProjectID: projectID, + }) if err != nil { - return fmt.Errorf("error fetching repository: %w", err) + return fmt.Errorf("error fetching repository instance: %w", err) + } + + dbProps, err := r.store.GetAllPropertiesForEntity(ctx, db.GetAllPropertiesForEntityParams{ + EntityID: repositoryID, + ProjectID: projectID, + ProviderID: entInstance.ProviderID, + }) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("error fetching properties: %w", err) } + modelProps, _ := models.DbPropsToModel(dbProps) + ent := models.NewEntityWithProperties(entInstance, modelProps) + logger.BusinessRecord(ctx).ProviderID = ent.Entity.ProviderID prov, err := r.providerManager.InstantiateFromID(ctx, ent.Entity.ProviderID) @@ -344,11 +390,9 @@ func (r *repositoryService) DeleteByName( ) error { logger.BusinessRecord(ctx).Project = projectID - // Build the full repository name fullName := fmt.Sprintf("%s/%s", repoOwner, repoName) - - // Get provider ID from name if specified var providerID uuid.UUID + if providerName != "" { prov, err := r.store.GetProviderByName(ctx, db.GetProviderByNameParams{ Name: providerName, @@ -376,15 +420,22 @@ func (r *repositoryService) DeleteByName( } if len(entities) == 0 { - return fmt.Errorf("error retrieving repository %s/%s in project %s: %w", repoOwner, repoName, projectID, sql.ErrNoRows) + return fmt.Errorf("repository not found: %w", sql.ErrNoRows) } - // Fetch the entity with properties - ent, err := r.propSvc.EntityWithPropertiesByID(ctx, entities[0].ID, nil) - if err != nil { - return fmt.Errorf("error fetching repository: %w", err) + entInstance := entities[0] + dbProps, err := r.store.GetAllPropertiesForEntity(ctx, db.GetAllPropertiesForEntityParams{ + EntityID: entInstance.ID, + ProjectID: entInstance.ProjectID, + ProviderID: entInstance.ProviderID, + }) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("error fetching properties: %w", err) } + modelProps, _ := models.DbPropsToModel(dbProps) + ent := models.NewEntityWithProperties(entInstance, modelProps) + logger.BusinessRecord(ctx).Repository = ent.Entity.ID prov, err := r.providerManager.InstantiateFromID(ctx, ent.Entity.ProviderID) @@ -394,7 +445,6 @@ func (r *repositoryService) DeleteByName( return r.deleteRepository(ctx, prov, ent) } - func (r *repositoryService) deleteRepository( ctx context.Context, client provifv1.Provider, repo *models.EntityWithProperties, ) error { @@ -410,8 +460,9 @@ func (r *repositoryService) deleteRepository( _, err = db.WithTransaction(r.store, func(t db.ExtendQuerier) (*pb.Repository, error) { // Remove the entity from the DB if err := t.DeleteEntity(ctx, db.DeleteEntityParams{ - ID: repo.Entity.ID, - ProjectID: repo.Entity.ProjectID, + ID: repo.Entity.ID, + ProjectID: repo.Entity.ProjectID, + ProviderID: repo.Entity.ProviderID, }); err != nil { return nil, fmt.Errorf("error deleting entity from DB: %w", err) } diff --git a/internal/repositories/service_test.go b/internal/repositories/service_test.go index 98e0da87b8..3aa369cb36 100644 --- a/internal/repositories/service_test.go +++ b/internal/repositories/service_test.go @@ -171,8 +171,8 @@ func TestRepositoryService_DeleteRepository(t *testing.T) { { Name: "DeleteByName fails when repo's entity cannot be retrieved", DeleteType: ByName, - DBSetup: newDBMock(withSuccessfulGetByName), - ServiceSetup: newPropSvcMock(withFailedEntityWithProps), + DBSetup: newDBMock(withFailedGetPropsByName), + ServiceSetup: newPropSvcMock(), ProviderSetup: newProviderMock(), ProviderManagerSetup: func(_ provinfv1.Provider) pf.ProviderManagerMockBuilder { return pf.NewProviderManagerMock() @@ -183,7 +183,7 @@ func TestRepositoryService_DeleteRepository(t *testing.T) { Name: "DeleteByName fails when provider cannot be instantiated", DeleteType: ByName, DBSetup: newDBMock(withSuccessfulGetByName), - ServiceSetup: newPropSvcMock(withSuccessfulEntityWithProps), + ServiceSetup: newPropSvcMock(), ProviderManagerSetup: func(_ provinfv1.Provider) pf.ProviderManagerMockBuilder { return pf.NewProviderManagerMock(pf.WithFailedInstantiateFromID) }, @@ -194,7 +194,7 @@ func TestRepositoryService_DeleteRepository(t *testing.T) { Name: "DeleteByName still works when entity cannot be deregistered", DeleteType: ByName, DBSetup: newDBMock(withSuccessfulGetByName, withSuccessfulDelete), - ServiceSetup: newPropSvcMock(withSuccessfulEntityWithProps), + ServiceSetup: newPropSvcMock(), ProviderManagerSetup: func(p provinfv1.Provider) pf.ProviderManagerMockBuilder { return pf.NewProviderManagerMock(pf.WithSuccessfulInstantiateFromID(p)) }, @@ -204,7 +204,7 @@ func TestRepositoryService_DeleteRepository(t *testing.T) { Name: "DeleteByName by ID fails when repo cannot be deleted from DB", DeleteType: ByName, DBSetup: newDBMock(withSuccessfulGetByName, withFailedDelete), - ServiceSetup: newPropSvcMock(withSuccessfulEntityWithProps), + ServiceSetup: newPropSvcMock(), ProviderManagerSetup: func(p provinfv1.Provider) pf.ProviderManagerMockBuilder { return pf.NewProviderManagerMock(pf.WithSuccessfulInstantiateFromID(p)) }, @@ -215,7 +215,7 @@ func TestRepositoryService_DeleteRepository(t *testing.T) { Name: "DeleteByName succeeds", DeleteType: ByName, DBSetup: newDBMock(withSuccessfulGetByName, withSuccessfulDelete), - ServiceSetup: newPropSvcMock(withSuccessfulEntityWithProps), + ServiceSetup: newPropSvcMock(), ProviderManagerSetup: func(p provinfv1.Provider) pf.ProviderManagerMockBuilder { return pf.NewProviderManagerMock(pf.WithSuccessfulInstantiateFromID(p)) }, @@ -224,8 +224,8 @@ func TestRepositoryService_DeleteRepository(t *testing.T) { { Name: "DeleteByID fails when repo entity cannot be retrieved", DeleteType: ByID, - DBSetup: newDBMock(), - ServiceSetup: newPropSvcMock(withFailedEntityWithProps), + DBSetup: newDBMock(withFailedGetByID), + ServiceSetup: newPropSvcMock(), ProviderSetup: newProviderMock(), ProviderManagerSetup: func(_ provinfv1.Provider) pf.ProviderManagerMockBuilder { return pf.NewProviderManagerMock() @@ -235,8 +235,8 @@ func TestRepositoryService_DeleteRepository(t *testing.T) { { Name: "DeleteByID fails when provider cannot be instantiated", DeleteType: ByID, - DBSetup: newDBMock(), - ServiceSetup: newPropSvcMock(withSuccessfulEntityWithProps), + DBSetup: newDBMock(withSuccessfulGetByID), + ServiceSetup: newPropSvcMock(), ProviderManagerSetup: func(_ provinfv1.Provider) pf.ProviderManagerMockBuilder { return pf.NewProviderManagerMock(pf.WithFailedInstantiateFromID) }, @@ -246,8 +246,8 @@ func TestRepositoryService_DeleteRepository(t *testing.T) { { Name: "DeleteByID works when entity cannot be deregistered", DeleteType: ByID, - DBSetup: newDBMock(withSuccessfulDelete), - ServiceSetup: newPropSvcMock(withSuccessfulEntityWithProps), + DBSetup: newDBMock(withSuccessfulGetByID, withSuccessfulDelete), + ServiceSetup: newPropSvcMock(), ProviderManagerSetup: func(p provinfv1.Provider) pf.ProviderManagerMockBuilder { return pf.NewProviderManagerMock(pf.WithSuccessfulInstantiateFromID(p)) }, @@ -256,8 +256,8 @@ func TestRepositoryService_DeleteRepository(t *testing.T) { { Name: "DeleteByID by ID fails when repo cannot be deleted from DB", DeleteType: ByID, - DBSetup: newDBMock(withFailedDelete), - ServiceSetup: newPropSvcMock(withSuccessfulEntityWithProps), + DBSetup: newDBMock(withSuccessfulGetByID, withFailedDelete), + ServiceSetup: newPropSvcMock(), ProviderManagerSetup: func(p provinfv1.Provider) pf.ProviderManagerMockBuilder { return pf.NewProviderManagerMock(pf.WithSuccessfulInstantiateFromID(p)) }, @@ -267,8 +267,8 @@ func TestRepositoryService_DeleteRepository(t *testing.T) { { Name: "DeleteByID succeeds", DeleteType: ByID, - DBSetup: newDBMock(withSuccessfulDelete), - ServiceSetup: newPropSvcMock(withSuccessfulEntityWithProps), + DBSetup: newDBMock(withSuccessfulGetByID, withSuccessfulDelete), + ServiceSetup: newPropSvcMock(), ProviderManagerSetup: func(p provinfv1.Provider) pf.ProviderManagerMockBuilder { return pf.NewProviderManagerMock(pf.WithSuccessfulInstantiateFromID(p)) }, @@ -308,17 +308,19 @@ func TestRepositoryService_GetRepositoryById(t *testing.T) { scenarios := []struct { Name string + DBSetup dbMockBuilder ServiceSetup propSvcMockBuilder ShouldSucceed bool }{ { Name: "Get by ID fails when DB call fails", - ServiceSetup: newPropSvcMock(withFailedEntityWithProps), + DBSetup: newDBMock(withFailedGetByID), + ServiceSetup: newPropSvcMock(), }, { - Name: "Get by ID succeeds", + Name: "Get by ID succeeds", + DBSetup: newDBMock(withSuccessfulGetByID), ServiceSetup: newPropSvcMock( - withSuccessfulEntityWithProps, withSuccessfulRetrieveAll, withSucessfulEntityToProto, ), @@ -333,7 +335,7 @@ func TestRepositoryService_GetRepositoryById(t *testing.T) { defer ctrl.Finish() ctx := context.Background() - svc := createService(ctrl, nil, scenario.ServiceSetup, nil, false) + svc := createService(ctrl, scenario.DBSetup, scenario.ServiceSetup, nil, false) _, err := svc.GetRepositoryById(ctx, repoID, projectID) if scenario.ShouldSucceed { @@ -363,7 +365,6 @@ func TestRepositoryService_GetRepositoryByName(t *testing.T) { Name: "Get by name succeeds", DBSetup: newDBMock(withSuccessfulGetByName), ServiceSetup: newPropSvcMock( - withSuccessfulEntityWithProps, withSuccessfulRetrieveAll, withSucessfulEntityToProto, ), @@ -400,7 +401,11 @@ func createService( var store db.Store if dbSetup != nil { store = dbSetup(ctrl) + } else { + // Provide a default mock store to prevent nil pointer panics! + store = mockdb.NewMockStore(ctrl) } + var providerManager manager.ProviderManager if providerSetup != nil { providerManager = providerSetup(ctrl) @@ -467,7 +472,7 @@ var ( ProjectID: projectID, RepoOwner: repoOwner, RepoName: repoName, - ProviderID: uuid.UUID{}, + ProviderID: uuid.New(), // Make sure ProviderID isn't nil } webhook = &gh.Hook{ ID: ptr.Ptr[int64](HookID), @@ -502,43 +507,60 @@ func newDBMock(opts ...func(dbMock)) dbMockBuilder { } } -func withFailedDelete(mock dbMock) { - mock.EXPECT().GetQuerierWithTransaction(gomock.Any()).Return(mock) - mock.EXPECT().BeginTransaction().Return(nil, nil) - mock.EXPECT(). - DeleteEntity(gomock.Any(), gomock.Any()). - Return(errDefault) - mock.EXPECT().Rollback(gomock.Any()).Return(nil) +func withFailedGetByID(mock dbMock) { + mock.EXPECT().GetEntityByID(gomock.Any(), gomock.Any()).Return(db.EntityInstance{}, errDefault).AnyTimes() } -func withSuccessfulDelete(mock dbMock) { - mock.EXPECT().GetQuerierWithTransaction(gomock.Any()).Return(mock) - mock.EXPECT().BeginTransaction().Return(nil, nil) - mock.EXPECT(). - DeleteEntity(gomock.Any(), gomock.Any()). - Return(nil) - mock.EXPECT().Commit(gomock.Any()).Return(nil) - mock.EXPECT().Rollback(gomock.Any()).Return(nil) +func withSuccessfulGetByID(mock dbMock) { + mock.EXPECT().GetEntityByID(gomock.Any(), gomock.Any()).Return(db.EntityInstance{ + ID: dbRepo.ID, + ProjectID: projectID, + ProviderID: dbRepo.ProviderID, + }, nil).AnyTimes() + mock.EXPECT().GetAllPropertiesForEntity(gomock.Any(), gomock.Any()).Return([]db.Property{}, nil).AnyTimes() +} + +func withSuccessfulGetByName(mock dbMock) { + mock.EXPECT().GetProviderByName(gomock.Any(), gomock.Any()).Return(provider, nil).AnyTimes() + mock.EXPECT().GetTypedEntitiesByPropertyV1(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return([]db.EntityInstance{{ + ID: dbRepo.ID, + ProjectID: projectID, + ProviderID: dbRepo.ProviderID, + }}, nil).AnyTimes() + mock.EXPECT().GetAllPropertiesForEntity(gomock.Any(), gomock.Any()).Return([]db.Property{}, nil).AnyTimes() } func withFailedGetByName(mock dbMock) { - // GetRepositoryByName now calls GetProviderByName + GetTypedEntitiesByPropertyV1 - mock.EXPECT(). - GetProviderByName(gomock.Any(), gomock.Any()). - Return(provider, nil) - mock.EXPECT(). - GetTypedEntitiesByPropertyV1(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return([]db.EntityInstance{}, errDefault) + mock.EXPECT().GetProviderByName(gomock.Any(), gomock.Any()).Return(provider, nil).AnyTimes() + mock.EXPECT().GetTypedEntitiesByPropertyV1(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return([]db.EntityInstance{}, errDefault).AnyTimes() } -func withSuccessfulGetByName(mock dbMock) { - // GetRepositoryByName now calls GetProviderByName + GetTypedEntitiesByPropertyV1 - mock.EXPECT(). - GetProviderByName(gomock.Any(), gomock.Any()). - Return(provider, nil) - mock.EXPECT(). - GetTypedEntitiesByPropertyV1(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return([]db.EntityInstance{{ID: dbRepo.ID}}, nil) +func withFailedGetPropsByName(mock dbMock) { + mock.EXPECT().GetProviderByName(gomock.Any(), gomock.Any()).Return(provider, nil).AnyTimes() + mock.EXPECT().GetTypedEntitiesByPropertyV1(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return([]db.EntityInstance{{ + ID: dbRepo.ID, + ProjectID: projectID, + ProviderID: dbRepo.ProviderID, + }}, nil).AnyTimes() + mock.EXPECT().GetAllPropertiesForEntity(gomock.Any(), gomock.Any()).Return(nil, errDefault).AnyTimes() +} + +func withFailedDelete(mock dbMock) { + mock.EXPECT().GetQuerierWithTransaction(gomock.Any()).Return(mock).AnyTimes() + mock.EXPECT().BeginTransaction().Return(nil, nil).AnyTimes() + mock.EXPECT().DeleteEntity(gomock.Any(), gomock.Any()).Return(errDefault).AnyTimes() + mock.EXPECT().Rollback(gomock.Any()).Return(nil).AnyTimes() +} + +func withSuccessfulDelete(mock dbMock) { + mock.EXPECT().GetQuerierWithTransaction(gomock.Any()).Return(mock).AnyTimes() + mock.EXPECT().BeginTransaction().Return(nil, nil).AnyTimes() + mock.EXPECT().DeleteEntity(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mock.EXPECT().Commit(gomock.Any()).Return(nil).AnyTimes() + mock.EXPECT().Rollback(gomock.Any()).Return(nil).AnyTimes() } func newGithubRepo(isPrivate bool) *gh.Repository { @@ -615,23 +637,6 @@ func newPropSvcMock(opts ...func(mock propSvcMock)) propSvcMockBuilder { } } -func withSuccessfulEntityWithProps(mock propSvcMock) { - mock.EXPECT(). - EntityWithPropertiesByID(gomock.Any(), gomock.Any(), gomock.Any()). - Return(models.NewEntityWithPropertiesFromInstance(models.EntityInstance{ - ID: dbRepo.ID, - Type: pb.Entity_ENTITY_REPOSITORIES, - ProjectID: projectID, - ProviderID: dbRepo.ProviderID, - }, publicProps), nil) -} - -func withFailedEntityWithProps(mock propSvcMock) { - mock.EXPECT(). - EntityWithPropertiesByID(gomock.Any(), gomock.Any(), gomock.Any()). - Return(nil, errDefault) -} - func withSucessfulEntityToProto(mock propSvcMock) { repo := instantiatePBRepo(false) mock.EXPECT().