From ef314b9031b4723d5c2e392502de8e0468500d25 Mon Sep 17 00:00:00 2001 From: Thomas Rosenstein Date: Sun, 14 Dec 2025 19:31:51 +0000 Subject: [PATCH] Improve sync responsiveness with background execution and context cancellation This change improves the responsiveness of the operator when handling deletion requests by running sync operations in the background and using context cancellation to interrupt stuck operations. Changes: - Add context field to Cluster struct, passed through New() - Add Cancel() method to cancel cluster's context - Add StartSync/EndSync/NeedsResync for managing background sync state - Run Sync() in a background goroutine so worker can process other events - Add context-aware DB connection methods (initDbConnWithContext) - Add RetryWithContext() that respects context cancellation - Cancel cluster context immediately when DeletionTimestamp detected - Use context-aware connections in syncRoles/syncDatabases - StartSync/NeedsResync check context cancellation to prevent new syncs during deletion (no need for separate deleted flag) Flow: 1. Sync event spawns background goroutine and returns immediately 2. If another sync arrives while one is running, needsResync flag is set 3. When sync completes, it checks needsResync and requeues if needed 4. Delete cancels context -> stuck DB operations return early -> mutex released 5. StartSync/NeedsResync return false when context cancelled 6. Delete proceeds without waiting for slow/stuck sync operations --- pkg/cluster/cluster.go | 68 +++++++++++++- pkg/cluster/cluster_test.go | 8 +- pkg/cluster/connection_pooler_test.go | 16 ++-- pkg/cluster/database.go | 39 ++++---- pkg/cluster/k8sres_test.go | 127 ++++++++++++++------------ pkg/cluster/pod_test.go | 5 +- pkg/cluster/streams_test.go | 8 +- pkg/cluster/sync.go | 16 ++-- pkg/cluster/sync_test.go | 12 +-- pkg/cluster/util_test.go | 6 +- pkg/cluster/volumes_test.go | 10 +- pkg/controller/postgresql.go | 52 +++++++++-- pkg/util/retryutil/retry_util.go | 19 +++- 13 files changed, 255 insertions(+), 131 deletions(-) diff --git a/pkg/cluster/cluster.go b/pkg/cluster/cluster.go index 1c3ad5295..c00021743 100644 --- a/pkg/cluster/cluster.go +++ b/pkg/cluster/cluster.go @@ -90,6 +90,11 @@ type Cluster struct { podSubscribersMu sync.RWMutex pgDb *sql.DB mu sync.Mutex + ctx context.Context + cancelFunc context.CancelFunc + syncMu sync.Mutex // protects syncRunning and needsResync + syncRunning bool + needsResync bool userSyncStrategy spec.UserSyncer deleteOptions metav1.DeleteOptions podEventsQueue *cache.FIFO @@ -122,9 +127,12 @@ type compareLogicalBackupJobResult struct { } // New creates a new cluster. This function should be called from a controller. -func New(cfg Config, kubeClient k8sutil.KubernetesClient, pgSpec acidv1.Postgresql, logger *logrus.Entry, eventRecorder record.EventRecorder) *Cluster { +func New(ctx context.Context, cfg Config, kubeClient k8sutil.KubernetesClient, pgSpec acidv1.Postgresql, logger *logrus.Entry, eventRecorder record.EventRecorder) *Cluster { deletePropagationPolicy := metav1.DeletePropagationOrphan + // Create a cancellable context for this cluster + clusterCtx, cancelFunc := context.WithCancel(ctx) + podEventsQueue := cache.NewFIFO(func(obj interface{}) (string, error) { e, ok := obj.(PodEvent) if !ok { @@ -139,6 +147,8 @@ func New(cfg Config, kubeClient k8sutil.KubernetesClient, pgSpec acidv1.Postgres } cluster := &Cluster{ + ctx: clusterCtx, + cancelFunc: cancelFunc, Config: cfg, Postgresql: pgSpec, pgUsers: make(map[string]spec.PgUser), @@ -178,6 +188,62 @@ func New(cfg Config, kubeClient k8sutil.KubernetesClient, pgSpec acidv1.Postgres return cluster } +// Cancel cancels the cluster's context, which will cause any ongoing +// context-aware operations (like Sync) to return early. +func (c *Cluster) Cancel() { + if c.cancelFunc != nil { + c.cancelFunc() + } +} + +// StartSync attempts to start a sync operation. Returns true if sync can start +// (no sync currently running and context not cancelled). Returns false if a sync +// is already running (needsResync is set) or if context is cancelled (deletion in progress). +func (c *Cluster) StartSync() bool { + c.syncMu.Lock() + defer c.syncMu.Unlock() + + // Check if context is cancelled (deletion in progress) + select { + case <-c.ctx.Done(): + return false + default: + } + + if c.syncRunning { + c.needsResync = true + return false + } + c.syncRunning = true + c.needsResync = false + return true +} + +// EndSync marks the sync operation as complete. +func (c *Cluster) EndSync() { + c.syncMu.Lock() + defer c.syncMu.Unlock() + c.syncRunning = false +} + +// NeedsResync returns true if a resync was requested while sync was running, +// and clears the flag. Returns false if context is cancelled (deletion in progress). +func (c *Cluster) NeedsResync() bool { + c.syncMu.Lock() + defer c.syncMu.Unlock() + + // Check if context is cancelled (deletion in progress) + select { + case <-c.ctx.Done(): + return false + default: + } + + result := c.needsResync + c.needsResync = false + return result +} + func (c *Cluster) clusterName() spec.NamespacedName { return util.NameFromMeta(c.ObjectMeta) } diff --git a/pkg/cluster/cluster_test.go b/pkg/cluster/cluster_test.go index 56b9640ef..e944fa965 100644 --- a/pkg/cluster/cluster_test.go +++ b/pkg/cluster/cluster_test.go @@ -43,7 +43,7 @@ var logger = logrus.New().WithField("test", "cluster") // 1 cluster, primary endpoint, 2 services, the secrets, the statefulset and pods being ready var eventRecorder = record.NewFakeRecorder(7) -var cl = New( +var cl = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -135,7 +135,7 @@ func TestCreate(t *testing.T) { client.Postgresqls(clusterNamespace).Create(context.TODO(), &pg, metav1.CreateOptions{}) client.Pods(clusterNamespace).Create(context.TODO(), &pod, metav1.CreateOptions{}) - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -1607,7 +1607,7 @@ func TestCompareLogicalBackupJob(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -1756,7 +1756,7 @@ func TestCrossNamespacedSecrets(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ ConnectionPooler: config.ConnectionPooler{ diff --git a/pkg/cluster/connection_pooler_test.go b/pkg/cluster/connection_pooler_test.go index 23213520f..d15a18139 100644 --- a/pkg/cluster/connection_pooler_test.go +++ b/pkg/cluster/connection_pooler_test.go @@ -162,7 +162,7 @@ func noEmptySync(cluster *Cluster, err error, reason SyncReason) error { func TestNeedConnectionPooler(t *testing.T) { testName := "Test how connection pooler can be enabled" - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ ProtectedRoles: []string{"admin"}, @@ -298,7 +298,7 @@ func TestConnectionPoolerCreateDeletion(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ ConnectionPooler: config.ConnectionPooler{ @@ -406,7 +406,7 @@ func TestConnectionPoolerSync(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ ConnectionPooler: config.ConnectionPooler{ @@ -668,7 +668,7 @@ func TestConnectionPoolerSync(t *testing.T) { func TestConnectionPoolerPodSpec(t *testing.T) { testName := "Test connection pooler pod template generation" - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ ProtectedRoles: []string{"admin"}, @@ -691,7 +691,7 @@ func TestConnectionPoolerPodSpec(t *testing.T) { ConnectionPooler: &acidv1.ConnectionPooler{}, EnableReplicaConnectionPooler: boolToPointer(true), } - var clusterNoDefaultRes = New( + var clusterNoDefaultRes = New(context.Background(), Config{ OpConfig: config.Config{ ProtectedRoles: []string{"admin"}, @@ -781,7 +781,7 @@ func TestConnectionPoolerPodSpec(t *testing.T) { func TestConnectionPoolerDeploymentSpec(t *testing.T) { testName := "Test connection pooler deployment spec generation" - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ ProtectedRoles: []string{"admin"}, @@ -985,7 +985,7 @@ func TestPoolerTLS(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -1066,7 +1066,7 @@ func TestPoolerTLS(t *testing.T) { func TestConnectionPoolerServiceSpec(t *testing.T) { testName := "Test connection pooler service spec generation" - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ ProtectedRoles: []string{"admin"}, diff --git a/pkg/cluster/database.go b/pkg/cluster/database.go index 56b5f3638..4f3d5c775 100644 --- a/pkg/cluster/database.go +++ b/pkg/cluster/database.go @@ -2,6 +2,7 @@ package cluster import ( "bytes" + "context" "database/sql" "fmt" "net" @@ -121,26 +122,31 @@ func (c *Cluster) initDbConn() error { if c.pgDb != nil { return nil } - return c.initDbConnWithName("") } -// Worker function for connection initialization. This function does not check -// if the connection is already open, if it is then it will be overwritten. -// Callers need to make sure no connection is open, otherwise we could leak -// connections +// initDbConnWithName initializes a database connection using the cluster's context. +// This function does not check if the connection is already open. func (c *Cluster) initDbConnWithName(dbname string) error { + return c.initDbConnWithNameContext(c.ctx, dbname) +} + +// initDbConnWithNameContext initializes a database connection with an explicit context. +// Use this when you need a custom context (e.g., different timeout, or context.Background() +// for operations that should not be cancelled). This function does not check if the +// connection is already open, callers need to ensure no connection is open to avoid leaks. +func (c *Cluster) initDbConnWithNameContext(ctx context.Context, dbname string) error { c.setProcessName("initializing db connection") var conn *sql.DB connstring := c.pgConnectionString(dbname) - finalerr := retryutil.Retry(constants.PostgresConnectTimeout, constants.PostgresConnectRetryTimeout, + finalerr := retryutil.RetryWithContext(ctx, constants.PostgresConnectTimeout, constants.PostgresConnectRetryTimeout, func() (bool, error) { var err error conn, err = sql.Open("postgres", connstring) if err == nil { - err = conn.Ping() + err = conn.PingContext(ctx) } if err == nil { @@ -268,9 +274,7 @@ func findUsersFromRotation(rotatedUsers []string, db *sql.DB) (map[string]string }() for rows.Next() { - var ( - rolname, roldatesuffix string - ) + var rolname, roldatesuffix string err := rows.Scan(&rolname, &roldatesuffix) if err != nil { return nil, fmt.Errorf("error when processing rows of deprecated users: %v", err) @@ -331,9 +335,7 @@ func (c *Cluster) cleanupRotatedUsers(rotatedUsers []string) error { // getDatabases returns the map of current databases with owners // The caller is responsible for opening and closing the database connection func (c *Cluster) getDatabases() (dbs map[string]string, err error) { - var ( - rows *sql.Rows - ) + var rows *sql.Rows if rows, err = c.pgDb.Query(getDatabasesSQL); err != nil { return nil, fmt.Errorf("could not query database: %v", err) @@ -551,9 +553,7 @@ func (c *Cluster) getOwnerRoles(dbObjPath string, withUser bool) (owners []strin // getExtension returns the list of current database extensions // The caller is responsible for opening and closing the database connection func (c *Cluster) getExtensions() (dbExtensions map[string]string, err error) { - var ( - rows *sql.Rows - ) + var rows *sql.Rows if rows, err = c.pgDb.Query(getExtensionsSQL); err != nil { return nil, fmt.Errorf("could not query database extensions: %v", err) @@ -598,7 +598,6 @@ func (c *Cluster) executeAlterExtension(extName, schemaName string) error { } func (c *Cluster) execCreateOrAlterExtension(extName, schemaName, statement, doing, operation string) error { - c.logger.Infof("%s %q schema %q", doing, extName, schemaName) if _, err := c.pgDb.Exec(fmt.Sprintf(statement, extName, schemaName)); err != nil { return fmt.Errorf("could not execute %s: %v", operation, err) @@ -610,9 +609,7 @@ func (c *Cluster) execCreateOrAlterExtension(extName, schemaName, statement, doi // getPublications returns the list of current database publications with tables // The caller is responsible for opening and closing the database connection func (c *Cluster) getPublications() (publications map[string]string, err error) { - var ( - rows *sql.Rows - ) + var rows *sql.Rows if rows, err = c.pgDb.Query(getPublicationsSQL); err != nil { return nil, fmt.Errorf("could not query database publications: %v", err) @@ -668,7 +665,6 @@ func (c *Cluster) executeAlterPublication(pubName, tableList string) error { } func (c *Cluster) execCreateOrAlterPublication(pubName, tableList, statement, doing, operation string) error { - c.logger.Debugf("%s %q with table list %q", doing, pubName, tableList) if _, err := c.pgDb.Exec(fmt.Sprintf(statement, pubName, tableList)); err != nil { return fmt.Errorf("could not execute %s: %v", operation, err) @@ -743,7 +739,6 @@ func (c *Cluster) installLookupFunction(poolerSchema, poolerUser string) error { constants.PostgresConnectTimeout, constants.PostgresConnectRetryTimeout, func() (bool, error) { - // At this moment we are not connected to any database if err := c.initDbConnWithName(dbname); err != nil { msg := "could not init database connection to %s" diff --git a/pkg/cluster/k8sres_test.go b/pkg/cluster/k8sres_test.go index 9226c27ac..daefbc03d 100644 --- a/pkg/cluster/k8sres_test.go +++ b/pkg/cluster/k8sres_test.go @@ -52,7 +52,7 @@ type ExpectedValue struct { } func TestGenerateSpiloJSONConfiguration(t *testing.T) { - var cluster = New( + cluster := New(context.Background(), Config{ OpConfig: config.Config{ ProtectedRoles: []string{"admin"}, @@ -242,11 +242,9 @@ func (c *mockConfigMap) Get(ctx context.Context, name string, options metav1.Get return configmap, nil } -type MockSecretGetter struct { -} +type MockSecretGetter struct{} -type MockConfigMapsGetter struct { -} +type MockConfigMapsGetter struct{} func (c *MockSecretGetter) Secrets(namespace string) v1core.SecretInterface { return &mockSecret{} @@ -262,6 +260,7 @@ func newMockKubernetesClient() k8sutil.KubernetesClient { ConfigMapsGetter: &MockConfigMapsGetter{}, } } + func newMockCluster(opConfig config.Config) *Cluster { cluster := &Cluster{ Config: Config{OpConfig: opConfig}, @@ -445,7 +444,6 @@ func TestPodEnvironmentSecretVariables(t *testing.T) { } } } - } // Test if the keys of an existing secret are properly referenced @@ -535,7 +533,6 @@ func TestCronjobEnvironmentSecretVariables(t *testing.T) { } } } - } func testEnvs(cluster *Cluster, podSpec *v1.PodTemplateSpec, role PostgresRole) error { @@ -565,7 +562,7 @@ func testEnvs(cluster *Cluster, podSpec *v1.PodTemplateSpec, role PostgresRole) } func TestGenerateSpiloPodEnvVars(t *testing.T) { - var dummyUUID = "efd12e58-5786-11e8-b5a7-06148230260c" + dummyUUID := "efd12e58-5786-11e8-b5a7-06148230260c" expectedClusterNameLabel := []ExpectedValue{ { @@ -1143,7 +1140,7 @@ func TestGetNumberOfInstances(t *testing.T) { } for _, tt := range tests { - var cluster = New( + cluster := New(context.Background(), Config{ OpConfig: tt.config, }, k8sutil.KubernetesClient{}, acidv1.Postgresql{}, logger, eventRecorder) @@ -1210,7 +1207,7 @@ func TestCloneEnv(t *testing.T) { }, } - var cluster = New( + cluster := New(context.Background(), Config{ OpConfig: config.Config{ WALES3Bucket: "wale-bucket", @@ -1423,7 +1420,7 @@ func TestStandbyEnv(t *testing.T) { }, } - var cluster = New( + cluster := New(context.Background(), Config{}, k8sutil.KubernetesClient{}, acidv1.Postgresql{}, logger, eventRecorder) for _, tt := range tests { @@ -1452,9 +1449,9 @@ func TestNodeAffinity(t *testing.T) { var err error var spec acidv1.PostgresSpec var cluster *Cluster - var spiloRunAsUser = int64(101) - var spiloRunAsGroup = int64(103) - var spiloFSGroup = int64(103) + spiloRunAsUser := int64(101) + spiloRunAsGroup := int64(103) + spiloFSGroup := int64(103) makeSpec := func(nodeAffinity *v1.NodeAffinity) acidv1.PostgresSpec { return acidv1.PostgresSpec{ @@ -1470,7 +1467,7 @@ func TestNodeAffinity(t *testing.T) { } } - cluster = New( + cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -1563,7 +1560,7 @@ func TestPodAffinity(t *testing.T) { } for _, tt := range tests { - cluster := New( + cluster := New(context.Background(), Config{ OpConfig: config.Config{ EnablePodAntiAffinity: tt.anti, @@ -1692,7 +1689,7 @@ func TestTLS(t *testing.T) { spiloRunAsUser := int64(101) spiloRunAsGroup := int64(103) spiloFSGroup := int64(103) - defaultMode := int32(0640) + defaultMode := int32(0o640) mountPath := "/tls" pg := acidv1.Postgresql{ @@ -1710,7 +1707,8 @@ func TestTLS(t *testing.T) { Size: "1G", }, TLS: &acidv1.TLSDescription{ - SecretName: tlsSecretName, CAFile: "ca.crt"}, + SecretName: tlsSecretName, CAFile: "ca.crt", + }, AdditionalVolumes: []acidv1.AdditionalVolume{ { Name: tlsSecretName, @@ -1726,7 +1724,7 @@ func TestTLS(t *testing.T) { }, } - var cluster = New( + cluster := New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -1981,7 +1979,7 @@ func TestAdditionalVolume(t *testing.T) { }, } - var cluster = New( + cluster := New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -2122,7 +2120,7 @@ func TestVolumeSelector(t *testing.T) { }, } - cluster := New( + cluster := New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -2221,7 +2219,7 @@ func TestSidecars(t *testing.T) { }, } - cluster = New( + cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -2367,7 +2365,6 @@ func TestSidecars(t *testing.T) { Env: scalyrEnv, VolumeMounts: mounts, }) - } func TestContainerValidation(t *testing.T) { @@ -2530,7 +2527,7 @@ func TestContainerValidation(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - cluster := New(tc.clusterConfig, k8sutil.KubernetesClient{}, acidv1.Postgresql{}, logger, eventRecorder) + cluster := New(context.Background(), tc.clusterConfig, k8sutil.KubernetesClient{}, acidv1.Postgresql{}, logger, eventRecorder) _, err := cluster.generateStatefulSet(&tc.spec) @@ -2581,13 +2578,15 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { if !masterLabelSelectorDisabled { if isPrimary { expectedLabels := &metav1.LabelSelector{ - MatchLabels: map[string]string{"spilo-role": "master", "cluster-name": "myapp-database"}} + MatchLabels: map[string]string{"spilo-role": "master", "cluster-name": "myapp-database"}, + } if !reflect.DeepEqual(podDisruptionBudget.Spec.Selector, expectedLabels) { return fmt.Errorf("MatchLabels incorrect, got %#v, expected %#v", podDisruptionBudget.Spec.Selector, expectedLabels) } } else { expectedLabels := &metav1.LabelSelector{ - MatchLabels: map[string]string{"cluster-name": "myapp-database", "critical-operation": "true"}} + MatchLabels: map[string]string{"cluster-name": "myapp-database", "critical-operation": "true"}, + } if !reflect.DeepEqual(podDisruptionBudget.Spec.Selector, expectedLabels) { return fmt.Errorf("MatchLabels incorrect, got %#v, expected %#v", podDisruptionBudget.Spec.Selector, expectedLabels) } @@ -2619,12 +2618,13 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }{ { scenario: "With multiple instances", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role"}, PDBNameFormat: "postgres-{cluster}-pdb"}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ ObjectMeta: metav1.ObjectMeta{Name: "myapp-database", Namespace: "myapp"}, - Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 3}}, + Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 3}, + }, logger, eventRecorder), check: []func(cluster *Cluster, podDisruptionBudget *policyv1.PodDisruptionBudget) error{ @@ -2636,12 +2636,13 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }, { scenario: "With zero instances", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role"}, PDBNameFormat: "postgres-{cluster}-pdb"}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ ObjectMeta: metav1.ObjectMeta{Name: "myapp-database", Namespace: "myapp"}, - Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 0}}, + Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 0}, + }, logger, eventRecorder), check: []func(cluster *Cluster, podDisruptionBudget *policyv1.PodDisruptionBudget) error{ @@ -2653,12 +2654,13 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }, { scenario: "With PodDisruptionBudget disabled", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role"}, PDBNameFormat: "postgres-{cluster}-pdb", EnablePodDisruptionBudget: util.False()}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ ObjectMeta: metav1.ObjectMeta{Name: "myapp-database", Namespace: "myapp"}, - Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 3}}, + Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 3}, + }, logger, eventRecorder), check: []func(cluster *Cluster, podDisruptionBudget *policyv1.PodDisruptionBudget) error{ @@ -2670,12 +2672,13 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }, { scenario: "With non-default PDBNameFormat and PodDisruptionBudget explicitly enabled", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role"}, PDBNameFormat: "postgres-{cluster}-databass-budget", EnablePodDisruptionBudget: util.True()}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ ObjectMeta: metav1.ObjectMeta{Name: "myapp-database", Namespace: "myapp"}, - Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 3}}, + Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 3}, + }, logger, eventRecorder), check: []func(cluster *Cluster, podDisruptionBudget *policyv1.PodDisruptionBudget) error{ @@ -2687,12 +2690,13 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }, { scenario: "With PDBMasterLabelSelector disabled", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role"}, PDBNameFormat: "postgres-{cluster}-pdb", EnablePodDisruptionBudget: util.True(), PDBMasterLabelSelector: util.False()}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ ObjectMeta: metav1.ObjectMeta{Name: "myapp-database", Namespace: "myapp"}, - Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 3}}, + Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 3}, + }, logger, eventRecorder), check: []func(cluster *Cluster, podDisruptionBudget *policyv1.PodDisruptionBudget) error{ @@ -2704,12 +2708,13 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }, { scenario: "With OwnerReference enabled", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role", EnableOwnerReferences: util.True()}, PDBNameFormat: "postgres-{cluster}-pdb", EnablePodDisruptionBudget: util.True()}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ ObjectMeta: metav1.ObjectMeta{Name: "myapp-database", Namespace: "myapp"}, - Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 3}}, + Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 3}, + }, logger, eventRecorder), check: []func(cluster *Cluster, podDisruptionBudget *policyv1.PodDisruptionBudget) error{ @@ -2739,12 +2744,13 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }{ { scenario: "With multiple instances", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role"}, PDBNameFormat: "postgres-{cluster}-pdb"}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ ObjectMeta: metav1.ObjectMeta{Name: "myapp-database", Namespace: "myapp"}, - Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 3}}, + Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 3}, + }, logger, eventRecorder), check: []func(cluster *Cluster, podDisruptionBudget *policyv1.PodDisruptionBudget) error{ @@ -2756,12 +2762,13 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }, { scenario: "With zero instances", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role"}, PDBNameFormat: "postgres-{cluster}-pdb"}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ ObjectMeta: metav1.ObjectMeta{Name: "myapp-database", Namespace: "myapp"}, - Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 0}}, + Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 0}, + }, logger, eventRecorder), check: []func(cluster *Cluster, podDisruptionBudget *policyv1.PodDisruptionBudget) error{ @@ -2773,12 +2780,13 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }, { scenario: "With PodDisruptionBudget disabled", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role"}, PDBNameFormat: "postgres-{cluster}-pdb", EnablePodDisruptionBudget: util.False()}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ ObjectMeta: metav1.ObjectMeta{Name: "myapp-database", Namespace: "myapp"}, - Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 3}}, + Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 3}, + }, logger, eventRecorder), check: []func(cluster *Cluster, podDisruptionBudget *policyv1.PodDisruptionBudget) error{ @@ -2790,12 +2798,13 @@ func TestGeneratePodDisruptionBudget(t *testing.T) { }, { scenario: "With OwnerReference enabled", - spec: New( + spec: New(context.Background(), Config{OpConfig: config.Config{Resources: config.Resources{ClusterNameLabel: "cluster-name", PodRoleLabel: "spilo-role", EnableOwnerReferences: util.True()}, PDBNameFormat: "postgres-{cluster}-pdb", EnablePodDisruptionBudget: util.True()}}, k8sutil.KubernetesClient{}, acidv1.Postgresql{ ObjectMeta: metav1.ObjectMeta{Name: "myapp-database", Namespace: "myapp"}, - Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 3}}, + Spec: acidv1.PostgresSpec{TeamID: "myapp", NumberOfInstances: 3}, + }, logger, eventRecorder), check: []func(cluster *Cluster, podDisruptionBudget *policyv1.PodDisruptionBudget) error{ @@ -2852,7 +2861,7 @@ func TestGenerateService(t *testing.T) { EnableMasterLoadBalancer: &enableLB, } - cluster = New( + cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -2899,11 +2908,10 @@ func TestGenerateService(t *testing.T) { cluster.OpConfig.ExternalTrafficPolicy = "Local" service = cluster.generateService(Master, &spec) assert.Equal(t, v1.ServiceExternalTrafficPolicyTypeLocal, service.Spec.ExternalTrafficPolicy) - } func TestCreateLoadBalancerLogic(t *testing.T) { - var cluster = New( + cluster := New(context.Background(), Config{ OpConfig: config.Config{ ProtectedRoles: []string{"admin"}, @@ -3116,7 +3124,7 @@ func TestEnableLoadBalancers(t *testing.T) { } for _, tt := range tests { - var cluster = New( + cluster := New(context.Background(), Config{ OpConfig: tt.config, }, client, tt.pgSpec, logger, eventRecorder) @@ -3811,7 +3819,7 @@ func TestGenerateResourceRequirements(t *testing.T) { } for _, tt := range tests { - var cluster = New( + cluster := New(context.Background(), Config{ OpConfig: tt.config, }, client, tt.pgSpec, logger, newEventRecorder) @@ -3996,7 +4004,7 @@ func TestGenerateLogicalBackupJob(t *testing.T) { } for _, tt := range tests { - var cluster = New( + cluster := New(context.Background(), Config{ OpConfig: tt.config, }, k8sutil.NewMockKubernetesClient(), acidv1.Postgresql{}, logger, eventRecorder) @@ -4305,7 +4313,7 @@ func TestTopologySpreadConstraints(t *testing.T) { }, } - cluster := New( + cluster := New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -4317,11 +4325,12 @@ func TestTopologySpreadConstraints(t *testing.T) { s, err := cluster.generateStatefulSet(&pg.Spec) assert.NoError(t, err) - assert.Contains(t, s.Spec.Template.Spec.TopologySpreadConstraints, v1.TopologySpreadConstraint{ - MaxSkew: int32(1), - TopologyKey: "topology.kubernetes.io/zone", - WhenUnsatisfiable: v1.DoNotSchedule, - LabelSelector: labelSelector, - }, + assert.Contains( + t, s.Spec.Template.Spec.TopologySpreadConstraints, v1.TopologySpreadConstraint{ + MaxSkew: int32(1), + TopologyKey: "topology.kubernetes.io/zone", + WhenUnsatisfiable: v1.DoNotSchedule, + LabelSelector: labelSelector, + }, ) } diff --git a/pkg/cluster/pod_test.go b/pkg/cluster/pod_test.go index 6ab3f9207..73817fffb 100644 --- a/pkg/cluster/pod_test.go +++ b/pkg/cluster/pod_test.go @@ -2,6 +2,7 @@ package cluster import ( "bytes" + "context" "fmt" "io" "net/http" @@ -25,7 +26,7 @@ func TestGetSwitchoverCandidate(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - var cluster = New( + cluster := New(context.Background(), Config{ OpConfig: config.Config{ PatroniAPICheckInterval: time.Duration(1), @@ -293,7 +294,7 @@ func TestPodIsNotRunning(t *testing.T) { func TestAllPodsRunning(t *testing.T) { client, _ := newFakeK8sSyncClient() - var cluster = New( + cluster := New(context.Background(), Config{ OpConfig: config.Config{ Resources: config.Resources{ diff --git a/pkg/cluster/streams_test.go b/pkg/cluster/streams_test.go index 934f2bfd4..86f26eea5 100644 --- a/pkg/cluster/streams_test.go +++ b/pkg/cluster/streams_test.go @@ -223,7 +223,7 @@ var ( }, } - cluster = New( + cluster = New(context.Background(), Config{ OpConfig: config.Config{ Auth: config.Auth{ @@ -529,7 +529,7 @@ func newFabricEventStream(streams []zalandov1.EventStream, annotations map[strin func TestSyncStreams(t *testing.T) { newClusterName := fmt.Sprintf("%s-2", pg.Name) pg.Name = newClusterName - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -688,7 +688,7 @@ func TestSameStreams(t *testing.T) { func TestUpdateStreams(t *testing.T) { pg.Name = fmt.Sprintf("%s-3", pg.Name) - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -787,7 +787,7 @@ func patchPostgresqlStreams(t *testing.T, cluster *Cluster, pgSpec *acidv1.Postg func TestDeleteStreams(t *testing.T) { pg.Name = fmt.Sprintf("%s-4", pg.Name) - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", diff --git a/pkg/cluster/sync.go b/pkg/cluster/sync.go index 7c478477a..6a57a1d14 100644 --- a/pkg/cluster/sync.go +++ b/pkg/cluster/sync.go @@ -76,7 +76,7 @@ func (c *Cluster) Sync(newSpec *acidv1.Postgresql) error { return err } - //TODO: mind the secrets of the deleted/new users + // TODO: mind the secrets of the deleted/new users if err = c.syncSecrets(); err != nil { err = fmt.Errorf("could not sync secrets: %v", err) return err @@ -869,7 +869,6 @@ func (c *Cluster) restartInstance(pod *v1.Pod, restartWait uint32) error { // AnnotationsToPropagate get the annotations to update if required // based on the annotations in postgres CRD func (c *Cluster) AnnotationsToPropagate(annotations map[string]string) map[string]string { - if annotations == nil { annotations = make(map[string]string) } @@ -1140,7 +1139,8 @@ func (c *Cluster) updateSecret( secretUsername string, generatedSecret *v1.Secret, retentionUsers *[]string, - currentTime time.Time) (*v1.Secret, error) { + currentTime time.Time, +) (*v1.Secret, error) { var ( secret *v1.Secret err error @@ -1296,7 +1296,8 @@ func (c *Cluster) rotatePasswordInSecret( secretUsername string, roleOrigin spec.RoleOrigin, currentTime time.Time, - retentionUsers *[]string) (string, error) { + retentionUsers *[]string, +) (string, error) { var ( err error nextRotationDate time.Time @@ -1521,7 +1522,7 @@ func (c *Cluster) syncDatabases() error { preparedDatabases := make([]string, 0) if err := c.initDbConn(); err != nil { - return fmt.Errorf("could not init database connection") + return fmt.Errorf("could not init database connection: %v", err) } defer func() { if err := c.closeDbConn(); err != nil { @@ -1605,7 +1606,7 @@ func (c *Cluster) syncPreparedDatabases() error { errors := make([]string, 0) for preparedDbName, preparedDB := range c.Spec.PreparedDatabases { - if err := c.initDbConnWithName(preparedDbName); err != nil { + if err := c.initDbConnWithNameContext(c.ctx, preparedDbName); err != nil { errors = append(errors, fmt.Sprintf("could not init connection to database %s: %v", preparedDbName, err)) continue } @@ -1749,7 +1750,8 @@ func (c *Cluster) syncLogicalBackupJob() error { } if len(cmp.deletedPodAnnotations) != 0 { templateMetadataReq := map[string]map[string]map[string]map[string]map[string]map[string]map[string]*string{ - "spec": {"jobTemplate": {"spec": {"template": {"metadata": {"annotations": {}}}}}}} + "spec": {"jobTemplate": {"spec": {"template": {"metadata": {"annotations": {}}}}}}, + } for _, anno := range cmp.deletedPodAnnotations { templateMetadataReq["spec"]["jobTemplate"]["spec"]["template"]["metadata"]["annotations"][anno] = nil } diff --git a/pkg/cluster/sync_test.go b/pkg/cluster/sync_test.go index f7d46d427..ef7ee52d1 100644 --- a/pkg/cluster/sync_test.go +++ b/pkg/cluster/sync_test.go @@ -88,7 +88,7 @@ func TestSyncStatefulSetsAnnotations(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -184,7 +184,7 @@ func TestPodAnnotationsSync(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PatroniAPICheckInterval: time.Duration(1), @@ -369,7 +369,7 @@ func TestCheckAndSetGlobalPostgreSQLConfiguration(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PodManagementPolicy: "ordered_ready", @@ -691,7 +691,7 @@ func TestSyncStandbyClusterConfiguration(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ PatroniAPICheckInterval: time.Duration(1), @@ -879,7 +879,7 @@ func TestUpdateSecret(t *testing.T) { } // new cluster with enabled password rotation - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ Auth: config.Auth{ @@ -1023,7 +1023,7 @@ func TestUpdateSecretNameConflict(t *testing.T) { }, } - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ Auth: config.Auth{ diff --git a/pkg/cluster/util_test.go b/pkg/cluster/util_test.go index 8413ca396..f90823d65 100644 --- a/pkg/cluster/util_test.go +++ b/pkg/cluster/util_test.go @@ -294,7 +294,7 @@ func newInheritedAnnotationsCluster(client k8sutil.KubernetesClient) (*Cluster, return nil, err } - cluster := New( + cluster := New(context.Background(), Config{ OpConfig: config.Config{ PatroniAPICheckInterval: time.Duration(1), @@ -658,6 +658,7 @@ func Test_trimCronjobName(t *testing.T) { func TestIsInMaintenanceWindow(t *testing.T) { cluster := New( + context.Background(), Config{ OpConfig: config.Config{ EnableMaintenanceWindows: util.True(), @@ -670,7 +671,8 @@ func TestIsInMaintenanceWindow(t *testing.T) { DefaultMemoryLimit: "300Mi", }, }, - }, k8sutil.KubernetesClient{}, acidv1.Postgresql{}, logger, eventRecorder) + }, k8sutil.KubernetesClient{}, acidv1.Postgresql{}, logger, eventRecorder, + ) cluster.Name = clusterName cluster.Namespace = namespace diff --git a/pkg/cluster/volumes_test.go b/pkg/cluster/volumes_test.go index 95ecc7624..4d9eb0189 100644 --- a/pkg/cluster/volumes_test.go +++ b/pkg/cluster/volumes_test.go @@ -59,7 +59,7 @@ func TestResizeVolumeClaim(t *testing.T) { assert.NoError(t, err) // new cluster with pvc storage resize mode and configured labels - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ Resources: config.Resources{ @@ -185,7 +185,7 @@ func TestMigrateEBS(t *testing.T) { namespace := "default" // new cluster with pvc storage resize mode and configured labels - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ Resources: config.Resources{ @@ -293,7 +293,7 @@ func TestMigrateGp3Support(t *testing.T) { namespace := "default" // new cluster with pvc storage resize mode and configured labels - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ Resources: config.Resources{ @@ -355,7 +355,7 @@ func TestManualGp2Gp3Support(t *testing.T) { namespace := "default" // new cluster with pvc storage resize mode and configured labels - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ Resources: config.Resources{ @@ -415,7 +415,7 @@ func TestDontTouchType(t *testing.T) { namespace := "default" // new cluster with pvc storage resize mode and configured labels - var cluster = New( + var cluster = New(context.Background(), Config{ OpConfig: config.Config{ Resources: config.Resources{ diff --git a/pkg/controller/postgresql.go b/pkg/controller/postgresql.go index ab5e0d772..4c0621d1d 100644 --- a/pkg/controller/postgresql.go +++ b/pkg/controller/postgresql.go @@ -167,7 +167,7 @@ func (c *Controller) addCluster(lg *logrus.Entry, clusterName spec.NamespacedNam } } - cl := cluster.New(c.makeClusterConfig(), c.KubeClient, *pgSpec, lg, c.eventRecorder) + cl := cluster.New(context.Background(), c.makeClusterConfig(), c.KubeClient, *pgSpec, lg, c.eventRecorder) cl.Run(c.stopCh) teamName := strings.ToLower(cl.Spec.TeamID) @@ -268,6 +268,7 @@ func (c *Controller) processEvent(event ClusterEvent) { // Check if this cluster has been marked for deletion if !event.NewSpec.ObjectMeta.DeletionTimestamp.IsZero() { lg.Infof("cluster has a DeletionTimestamp of %s, starting deletion now.", event.NewSpec.ObjectMeta.DeletionTimestamp.Format(time.RFC3339)) + cl.Cancel() // Cancel any ongoing operations if err = cl.Delete(); err != nil { cl.Error = fmt.Sprintf("error deleting cluster and its resources: %v", err) c.eventRecorder.Eventf(cl.GetReference(), v1.EventTypeWarning, "Delete", "%v", cl.Error) @@ -278,7 +279,6 @@ func (c *Controller) processEvent(event ClusterEvent) { return } - lg.Infoln("update of the cluster started") err = cl.Update(event.OldSpec, event.NewSpec) if err != nil { cl.Error = fmt.Sprintf("could not update cluster: %v", err) @@ -306,6 +306,7 @@ func (c *Controller) processEvent(event ClusterEvent) { // when using finalizers the deletion already happened if c.opConfig.EnableFinalizers == nil || !*c.opConfig.EnableFinalizers { lg.Infoln("deletion of the cluster started") + cl.Cancel() // Cancel any ongoing operations if err := cl.Delete(); err != nil { cl.Error = fmt.Sprintf("could not delete cluster: %v", err) c.eventRecorder.Eventf(cl.GetReference(), v1.EventTypeWarning, "Delete", "%v", cl.Error) @@ -331,8 +332,6 @@ func (c *Controller) processEvent(event ClusterEvent) { lg.Infof("cluster has been deleted") case EventSync: - lg.Infof("syncing of the cluster started") - // no race condition because a cluster is always processed by single worker if !clusterFound { cl, err = c.addCluster(lg, clusterName, event.NewSpec) @@ -347,22 +346,42 @@ func (c *Controller) processEvent(event ClusterEvent) { // has this cluster been marked as deleted already, then we shall start cleaning up if !cl.ObjectMeta.DeletionTimestamp.IsZero() { lg.Infof("cluster has a DeletionTimestamp of %s, starting deletion now.", cl.ObjectMeta.DeletionTimestamp.Format(time.RFC3339)) + cl.Cancel() // Cancel any ongoing operations if err = cl.Delete(); err != nil { cl.Error = fmt.Sprintf("error deleting cluster and its resources: %v", err) c.eventRecorder.Eventf(cl.GetReference(), v1.EventTypeWarning, "Delete", "%v", cl.Error) lg.Error(cl.Error) return } - } else { - if err = cl.Sync(event.NewSpec); err != nil { + return + } + + // Try to start sync - returns false if sync already running or cluster deleted + if !cl.StartSync() { + lg.Infof("sync already in progress, will resync when current sync completes") + return + } + + // Run sync in background goroutine so we can process other events (like delete) + lg.Infof("syncing of the cluster started (background)") + go func() { + defer cl.EndSync() + + if err := cl.Sync(event.NewSpec); err != nil { cl.Error = fmt.Sprintf("could not sync cluster: %v", err) c.eventRecorder.Eventf(cl.GetReference(), v1.EventTypeWarning, "Sync", "%v", cl.Error) lg.Error(cl.Error) return } + cl.Error = "" lg.Infof("cluster has been synced") - } - cl.Error = "" + + // Check if resync was requested while we were syncing + if cl.NeedsResync() { + lg.Infof("resync requested, queueing new sync event") + c.queueClusterEvent(nil, event.NewSpec, EventSync) + } + }() } } @@ -478,6 +497,19 @@ func (c *Controller) queueClusterEvent(informerOldSpec, informerNewSpec *acidv1. } } + // If the cluster is marked for deletion, cancel any ongoing operations immediately + // This unblocks stuck Sync operations so the delete can proceed + if informerNewSpec != nil && !informerNewSpec.ObjectMeta.DeletionTimestamp.IsZero() { + c.clustersMu.RLock() + if cl, found := c.clusters[clusterName]; found { + c.logger.WithField("cluster-name", clusterName).Infof( + "cluster marked for deletion (DeletionTimestamp: %s), cancelling ongoing operations", + informerNewSpec.ObjectMeta.DeletionTimestamp.Format(time.RFC3339)) + cl.Cancel() + } + c.clustersMu.RUnlock() + } + if clusterError != "" && eventType != EventDelete { c.logger.WithField("cluster-name", clusterName).Debugf("skipping %q event for the invalid cluster: %s", eventType, clusterError) @@ -579,9 +611,13 @@ func (c *Controller) postgresqlUpdate(prev, cur interface{}) { // Avoid the infinite recursion for status updates if reflect.DeepEqual(pgOld.Spec, pgNew.Spec) { if reflect.DeepEqual(pgNew.Annotations, pgOld.Annotations) { + c.logger.WithField("cluster-name", clusterName).Debugf( + "UPDATE event: no spec/annotation changes, skipping") return } } + + c.logger.WithField("cluster-name", clusterName).Infof("UPDATE event: spec or annotations changed, queueing event") c.queueClusterEvent(pgOld, pgNew, EventUpdate) } } diff --git a/pkg/util/retryutil/retry_util.go b/pkg/util/retryutil/retry_util.go index 868ba6e98..b5fab3b47 100644 --- a/pkg/util/retryutil/retry_util.go +++ b/pkg/util/retryutil/retry_util.go @@ -1,6 +1,7 @@ package retryutil import ( + "context" "fmt" "time" ) @@ -25,7 +26,7 @@ func (t *Ticker) Tick() { <-t.ticker.C } // Retry is a wrapper around RetryWorker that provides a real RetryTicker func Retry(interval time.Duration, timeout time.Duration, f func() (bool, error)) error { - //TODO: make the retry exponential + // TODO: make the retry exponential if timeout < interval { return fmt.Errorf("timeout(%s) should be greater than interval(%v)", timeout, interval) } @@ -33,6 +34,18 @@ func Retry(interval time.Duration, timeout time.Duration, f func() (bool, error) return RetryWorker(interval, timeout, tick, f) } +// RetryWithContext is like Retry but checks for context cancellation before each attempt. +func RetryWithContext(ctx context.Context, interval time.Duration, timeout time.Duration, f func() (bool, error)) error { + return Retry(interval, timeout, func() (bool, error) { + select { + case <-ctx.Done(): + return false, ctx.Err() + default: + return f() + } + }) +} + // RetryWorker calls ConditionFunc until either: // * it returns boolean true // * a timeout expires @@ -41,8 +54,8 @@ func RetryWorker( interval time.Duration, timeout time.Duration, tick RetryTicker, - f func() (bool, error)) error { - + f func() (bool, error), +) error { maxRetries := int(timeout / interval) defer tick.Stop()