diff --git a/cmd/bb_portal/BUILD.bazel b/cmd/bb_portal/BUILD.bazel index 32194d54..6a455bfa 100644 --- a/cmd/bb_portal/BUILD.bazel +++ b/cmd/bb_portal/BUILD.bazel @@ -10,7 +10,6 @@ go_library( importpath = "github.com/buildbarn/bb-portal/cmd/bb_portal", visibility = ["//visibility:private"], deps = [ - "//ent/gen/ent", "//ent/gen/ent/migrate", "//ent/gen/ent/runtime", "//internal/api/grpc/bes", diff --git a/cmd/bb_portal/main.go b/cmd/bb_portal/main.go index 9f8fd9eb..be5ad1d3 100644 --- a/cmd/bb_portal/main.go +++ b/cmd/bb_portal/main.go @@ -30,7 +30,6 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "github.com/buildbarn/bb-portal/ent/gen/ent" "github.com/buildbarn/bb-portal/ent/gen/ent/migrate" "github.com/buildbarn/bb-portal/internal/api/grpc/bes" "github.com/buildbarn/bb-portal/internal/api/http/bepuploader" @@ -164,7 +163,7 @@ func newBuildEventStreamService( } dbAuthService := dbauthservice.NewDbAuthService(dbClient.Ent(), clock.SystemClock, instanceNameAuthorizer, time.Second*5) - err = addGraphqlHandler(configuration, besConfiguration, dbAuthService, dependenciesGroup, grpcClientFactory, router, dbClient.Ent(), tracerProvider) + err = addGraphqlHandler(configuration, besConfiguration, dbAuthService, dependenciesGroup, grpcClientFactory, router, dbClient, tracerProvider) if err != nil { return util.StatusWrap(err, "Failed to add GraphQL handler for BuildEventStreamService") } @@ -207,7 +206,7 @@ func addGraphqlHandler( dependenciesGroup program.Group, grpcClientFactory bb_grpc.ClientFactory, router *mux.Router, - dbClient *ent.Client, + dbClient database.Client, tracerProvider trace.TracerProvider, ) error { srv := graphql.NewGraphqlHandler(dbClient, tracerProvider) diff --git a/internal/graphql/BUILD.bazel b/internal/graphql/BUILD.bazel index 5fd9d228..3ff60fb4 100644 --- a/internal/graphql/BUILD.bazel +++ b/internal/graphql/BUILD.bazel @@ -26,6 +26,7 @@ go_library( "//ent/gen/ent/invocationtarget", "//ent/gen/ent/sourcecontrol", "//ent/gen/ent/target", + "//internal/database", "//internal/graphql/helpers", "//internal/graphql/model", "//pkg/uuidgql", diff --git a/internal/graphql/custom.resolvers.go b/internal/graphql/custom.resolvers.go index 9c33d00d..33965d7f 100644 --- a/internal/graphql/custom.resolvers.go +++ b/internal/graphql/custom.resolvers.go @@ -56,7 +56,7 @@ func (r *connectionMetadataResolver) TimeSinceLastConnectionMillis(ctx context.C func (r *queryResolver) GetAuthenticatedUser(ctx context.Context, userUUID uuid.UUID) (*ent.AuthenticatedUser, error) { // CollectFields here is used to avoid the N+1 query problem. Ent shouldn't // need it, but somehow it does. - query, err := r.client.AuthenticatedUser.Query().Where(authenticateduser.UserUUID(userUUID)).CollectFields(ctx) + query, err := r.db.Ent().AuthenticatedUser.Query().Where(authenticateduser.UserUUID(userUUID)).CollectFields(ctx) if err != nil { return nil, err } @@ -67,7 +67,7 @@ func (r *queryResolver) GetAuthenticatedUser(ctx context.Context, userUUID uuid. func (r *queryResolver) GetBazelInvocation(ctx context.Context, invocationID uuid.UUID) (*ent.BazelInvocation, error) { // CollectFields here is used to avoid the N+1 query problem. Ent shouldn't // need it, but somehow it does. - query, err := r.client.BazelInvocation.Query().Where(bazelinvocation.InvocationID(invocationID)).CollectFields(ctx) + query, err := r.db.Ent().BazelInvocation.Query().Where(bazelinvocation.InvocationID(invocationID)).CollectFields(ctx) if err != nil { return nil, err } @@ -78,7 +78,7 @@ func (r *queryResolver) GetBazelInvocation(ctx context.Context, invocationID uui func (r *queryResolver) GetBuild(ctx context.Context, buildUUID uuid.UUID) (*ent.Build, error) { // CollectFields here is used to avoid the N+1 query problem. Ent shouldn't // need it, but somehow it does. - query, err := r.client.Build.Query().Where(build.BuildUUID(buildUUID)).CollectFields(ctx) + query, err := r.db.Ent().Build.Query().Where(build.BuildUUID(buildUUID)).CollectFields(ctx) if err != nil { return nil, err } @@ -89,7 +89,7 @@ func (r *queryResolver) GetBuild(ctx context.Context, buildUUID uuid.UUID) (*ent func (r *queryResolver) GetTarget(ctx context.Context, instanceName, label, aspect, targetKind string) (*ent.Target, error) { // CollectFields here is used to avoid the N+1 query problem. Ent shouldn't // need it, but somehow it does. - query, err := r.client.Target.Query().Where( + query, err := r.db.Ent().Target.Query().Where( target.LabelEQ(label), target.Aspect(aspect), target.TargetKind(targetKind), diff --git a/internal/graphql/ent.resolvers.go b/internal/graphql/ent.resolvers.go index 01c7103d..5f755fd5 100644 --- a/internal/graphql/ent.resolvers.go +++ b/internal/graphql/ent.resolvers.go @@ -184,25 +184,25 @@ func (r *queryResolver) Nodes(ctx context.Context, ids []string) ([]ent.Noder, e // FindBazelInvocations is the resolver for the findBazelInvocations field. func (r *queryResolver) FindBazelInvocations(ctx context.Context, after *entgql.Cursor[int64], first *int, before *entgql.Cursor[int64], last *int, orderBy *ent.BazelInvocationOrder, where *ent.BazelInvocationWhereInput) (*ent.BazelInvocationConnection, error) { helpers.PaginationCursorsToUTC(after, before) - return r.client.BazelInvocation.Query().Paginate(ctx, after, first, before, last, ent.WithBazelInvocationFilter(where.Filter), ent.WithBazelInvocationOrder(orderBy)) + return r.db.Ent().BazelInvocation.Query().Paginate(ctx, after, first, before, last, ent.WithBazelInvocationFilter(where.Filter), ent.WithBazelInvocationOrder(orderBy)) } // FindBuilds is the resolver for the findBuilds field. func (r *queryResolver) FindBuilds(ctx context.Context, after *entgql.Cursor[int64], first *int, before *entgql.Cursor[int64], last *int, orderBy *ent.BuildOrder, where *ent.BuildWhereInput) (*ent.BuildConnection, error) { helpers.PaginationCursorsToUTC(after, before) - return r.client.Build.Query().Paginate(ctx, after, first, before, last, ent.WithBuildFilter(where.Filter), ent.WithBuildOrder(orderBy)) + return r.db.Ent().Build.Query().Paginate(ctx, after, first, before, last, ent.WithBuildFilter(where.Filter), ent.WithBuildOrder(orderBy)) } // FindTargets is the resolver for the findTargets field. func (r *queryResolver) FindTargets(ctx context.Context, after *entgql.Cursor[int64], first *int, before *entgql.Cursor[int64], last *int, where *ent.TargetWhereInput) (*ent.TargetConnection, error) { helpers.PaginationCursorsToUTC(after, before) - return r.client.Target.Query().Paginate(ctx, after, first, before, last, ent.WithTargetFilter(where.Filter)) + return r.db.Ent().Target.Query().Paginate(ctx, after, first, before, last, ent.WithTargetFilter(where.Filter)) } // FindTestSummaries is the resolver for the findTestSummaries field. func (r *queryResolver) FindTestSummaries(ctx context.Context, after *entgql.Cursor[int64], first *int, before *entgql.Cursor[int64], last *int, orderBy *ent.TestSummaryOrder, where *ent.TestSummaryWhereInput) (*ent.TestSummaryConnection, error) { helpers.PaginationCursorsToUTC(after, before) - return r.client.TestSummary.Query().Paginate(ctx, after, first, before, last, ent.WithTestSummaryFilter(where.Filter), ent.WithTestSummaryOrder(orderBy)) + return r.db.Ent().TestSummary.Query().Paginate(ctx, after, first, before, last, ent.WithTestSummaryFilter(where.Filter), ent.WithTestSummaryOrder(orderBy)) } // ID is the resolver for the id field. diff --git a/internal/graphql/resolver.go b/internal/graphql/resolver.go index 291f1bd7..695cb43b 100644 --- a/internal/graphql/resolver.go +++ b/internal/graphql/resolver.go @@ -2,8 +2,7 @@ package graphql import ( "github.com/99designs/gqlgen/graphql" - - "github.com/buildbarn/bb-portal/ent/gen/ent" + "github.com/buildbarn/bb-portal/internal/database" ) // This file will not be regenerated automatically. @@ -12,12 +11,12 @@ import ( // The Resolver Type for DI type Resolver struct { - client *ent.Client + db database.Client } // NewSchema creates a graphql executable schema. -func NewSchema(client *ent.Client) graphql.ExecutableSchema { +func NewSchema(db database.Client) graphql.ExecutableSchema { return NewExecutableSchema(Config{ - Resolvers: &Resolver{client: client}, + Resolvers: &Resolver{db: db}, }) } diff --git a/internal/graphql/server.go b/internal/graphql/server.go index 6e0aef07..6762c332 100644 --- a/internal/graphql/server.go +++ b/internal/graphql/server.go @@ -1,21 +1,19 @@ package graphql import ( - "entgo.io/contrib/entgql" "github.com/99designs/gqlgen/graphql/handler" "github.com/99designs/gqlgen/graphql/handler/extension" "github.com/99designs/gqlgen/graphql/handler/transport" "github.com/aereal/otelgqlgen" - "github.com/buildbarn/bb-portal/ent/gen/ent" + "github.com/buildbarn/bb-portal/internal/database" "go.opentelemetry.io/otel/trace" ) // NewGraphqlHandler creates a new GraphQL handler -func NewGraphqlHandler(dbClient *ent.Client, tracerProvider trace.TracerProvider) *handler.Server { +func NewGraphqlHandler(dbClient database.Client, tracerProvider trace.TracerProvider) *handler.Server { srv := handler.New(NewSchema(dbClient)) srv.AddTransport(transport.POST{}) srv.Use(extension.Introspection{}) - srv.Use(entgql.Transactioner{TxOpener: dbClient}) srv.Use(otelgqlgen.New(otelgqlgen.WithTracerProvider(tracerProvider))) // A fixed complexity limit for incoming GraphQL queries. // See https://gqlgen.com/master/reference/complexity/ for more details. diff --git a/test/integrationtest/BUILD.bazel b/test/integrationtest/BUILD.bazel index 9e5061ed..18cc8087 100644 --- a/test/integrationtest/BUILD.bazel +++ b/test/integrationtest/BUILD.bazel @@ -9,7 +9,6 @@ go_test( ], data = glob(["testdata/**"]) + ["//frontend/src/graphql:__generated__"], deps = [ - "//ent/gen/ent", "//ent/gen/ent/runtime", "//internal/api/http/bepuploader", "//internal/database", diff --git a/test/integrationtest/integration_test.go b/test/integrationtest/integration_test.go index 36ff6219..b252dce2 100644 --- a/test/integrationtest/integration_test.go +++ b/test/integrationtest/integration_test.go @@ -337,7 +337,7 @@ func runTestCase(t *testing.T, queryRegistry *testkit.QueryRegistry, testCase te }) } - graphqlServer := startGraphqlHTTPServer(t, db.Ent()) + graphqlServer := startGraphqlHTTPServer(t, db) runGraphqlTestCases(ctx, t, graphqlServer.URL, queryRegistry, testCase) } diff --git a/test/integrationtest/util_test.go b/test/integrationtest/util_test.go index a1664b49..375dd3a6 100644 --- a/test/integrationtest/util_test.go +++ b/test/integrationtest/util_test.go @@ -8,7 +8,6 @@ import ( "testing" gqlgen "github.com/99designs/gqlgen/graphql" - "github.com/buildbarn/bb-portal/ent/gen/ent" "github.com/buildbarn/bb-portal/internal/api/http/bepuploader" "github.com/buildbarn/bb-portal/internal/database" "github.com/buildbarn/bb-portal/internal/database/dbauthservice" @@ -50,8 +49,8 @@ func setupTestBepUploader(t *testing.T, db database.Client, testCase testCase) * return bepUploader } -func startGraphqlHTTPServer(t *testing.T, client *ent.Client) *httptest.Server { - srv := graphql.NewGraphqlHandler(client, trace.NewNoopTracerProvider()) +func startGraphqlHTTPServer(t *testing.T, db database.Client) *httptest.Server { + srv := graphql.NewGraphqlHandler(db, trace.NewNoopTracerProvider()) // Bypass DB auth service for integration tests. srv.AroundOperations(func(ctx context.Context, next gqlgen.OperationHandler) gqlgen.ResponseHandler {