Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 67 additions & 1 deletion pkg/cluster/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ type Cluster struct {
podSubscribersMu sync.RWMutex
pgDb *sql.DB
mu sync.Mutex
ctx context.Context
cancelFunc context.CancelFunc
syncMu sync.Mutex // protects syncRunning and needsResync
syncRunning bool
needsResync bool
userSyncStrategy spec.UserSyncer
deleteOptions metav1.DeleteOptions
podEventsQueue *cache.FIFO
Expand Down Expand Up @@ -122,9 +127,12 @@ type compareLogicalBackupJobResult struct {
}

// New creates a new cluster. This function should be called from a controller.
func New(cfg Config, kubeClient k8sutil.KubernetesClient, pgSpec acidv1.Postgresql, logger *logrus.Entry, eventRecorder record.EventRecorder) *Cluster {
func New(ctx context.Context, cfg Config, kubeClient k8sutil.KubernetesClient, pgSpec acidv1.Postgresql, logger *logrus.Entry, eventRecorder record.EventRecorder) *Cluster {
deletePropagationPolicy := metav1.DeletePropagationOrphan

// Create a cancellable context for this cluster
clusterCtx, cancelFunc := context.WithCancel(ctx)

podEventsQueue := cache.NewFIFO(func(obj interface{}) (string, error) {
e, ok := obj.(PodEvent)
if !ok {
Expand All @@ -139,6 +147,8 @@ func New(cfg Config, kubeClient k8sutil.KubernetesClient, pgSpec acidv1.Postgres
}

cluster := &Cluster{
ctx: clusterCtx,
cancelFunc: cancelFunc,
Config: cfg,
Postgresql: pgSpec,
pgUsers: make(map[string]spec.PgUser),
Expand Down Expand Up @@ -178,6 +188,62 @@ func New(cfg Config, kubeClient k8sutil.KubernetesClient, pgSpec acidv1.Postgres
return cluster
}

// Cancel cancels the cluster's context, which will cause any ongoing
// context-aware operations (like Sync) to return early.
func (c *Cluster) Cancel() {
if c.cancelFunc != nil {
c.cancelFunc()
}
}

// StartSync attempts to start a sync operation. Returns true if sync can start
// (no sync currently running and context not cancelled). Returns false if a sync
// is already running (needsResync is set) or if context is cancelled (deletion in progress).
func (c *Cluster) StartSync() bool {
c.syncMu.Lock()
defer c.syncMu.Unlock()

// Check if context is cancelled (deletion in progress)
select {
case <-c.ctx.Done():
return false
default:
}

if c.syncRunning {
c.needsResync = true
return false
}
c.syncRunning = true
c.needsResync = false
return true
}

// EndSync marks the sync operation as complete.
func (c *Cluster) EndSync() {
c.syncMu.Lock()
defer c.syncMu.Unlock()
c.syncRunning = false
}

// NeedsResync returns true if a resync was requested while sync was running,
// and clears the flag. Returns false if context is cancelled (deletion in progress).
func (c *Cluster) NeedsResync() bool {
c.syncMu.Lock()
defer c.syncMu.Unlock()

// Check if context is cancelled (deletion in progress)
select {
case <-c.ctx.Done():
return false
default:
}

result := c.needsResync
c.needsResync = false
return result
}

func (c *Cluster) clusterName() spec.NamespacedName {
return util.NameFromMeta(c.ObjectMeta)
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/cluster/cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ var logger = logrus.New().WithField("test", "cluster")
// 1 cluster, primary endpoint, 2 services, the secrets, the statefulset and pods being ready
var eventRecorder = record.NewFakeRecorder(7)

var cl = New(
var cl = New(context.Background(),
Config{
OpConfig: config.Config{
PodManagementPolicy: "ordered_ready",
Expand Down Expand Up @@ -135,7 +135,7 @@ func TestCreate(t *testing.T) {
client.Postgresqls(clusterNamespace).Create(context.TODO(), &pg, metav1.CreateOptions{})
client.Pods(clusterNamespace).Create(context.TODO(), &pod, metav1.CreateOptions{})

var cluster = New(
var cluster = New(context.Background(),
Config{
OpConfig: config.Config{
PodManagementPolicy: "ordered_ready",
Expand Down Expand Up @@ -1607,7 +1607,7 @@ func TestCompareLogicalBackupJob(t *testing.T) {
},
}

var cluster = New(
var cluster = New(context.Background(),
Config{
OpConfig: config.Config{
PodManagementPolicy: "ordered_ready",
Expand Down Expand Up @@ -1756,7 +1756,7 @@ func TestCrossNamespacedSecrets(t *testing.T) {
},
}

var cluster = New(
var cluster = New(context.Background(),
Config{
OpConfig: config.Config{
ConnectionPooler: config.ConnectionPooler{
Expand Down
16 changes: 8 additions & 8 deletions pkg/cluster/connection_pooler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func noEmptySync(cluster *Cluster, err error, reason SyncReason) error {

func TestNeedConnectionPooler(t *testing.T) {
testName := "Test how connection pooler can be enabled"
var cluster = New(
var cluster = New(context.Background(),
Config{
OpConfig: config.Config{
ProtectedRoles: []string{"admin"},
Expand Down Expand Up @@ -298,7 +298,7 @@ func TestConnectionPoolerCreateDeletion(t *testing.T) {
},
}

var cluster = New(
var cluster = New(context.Background(),
Config{
OpConfig: config.Config{
ConnectionPooler: config.ConnectionPooler{
Expand Down Expand Up @@ -406,7 +406,7 @@ func TestConnectionPoolerSync(t *testing.T) {
},
}

var cluster = New(
var cluster = New(context.Background(),
Config{
OpConfig: config.Config{
ConnectionPooler: config.ConnectionPooler{
Expand Down Expand Up @@ -668,7 +668,7 @@ func TestConnectionPoolerSync(t *testing.T) {

func TestConnectionPoolerPodSpec(t *testing.T) {
testName := "Test connection pooler pod template generation"
var cluster = New(
var cluster = New(context.Background(),
Config{
OpConfig: config.Config{
ProtectedRoles: []string{"admin"},
Expand All @@ -691,7 +691,7 @@ func TestConnectionPoolerPodSpec(t *testing.T) {
ConnectionPooler: &acidv1.ConnectionPooler{},
EnableReplicaConnectionPooler: boolToPointer(true),
}
var clusterNoDefaultRes = New(
var clusterNoDefaultRes = New(context.Background(),
Config{
OpConfig: config.Config{
ProtectedRoles: []string{"admin"},
Expand Down Expand Up @@ -781,7 +781,7 @@ func TestConnectionPoolerPodSpec(t *testing.T) {

func TestConnectionPoolerDeploymentSpec(t *testing.T) {
testName := "Test connection pooler deployment spec generation"
var cluster = New(
var cluster = New(context.Background(),
Config{
OpConfig: config.Config{
ProtectedRoles: []string{"admin"},
Expand Down Expand Up @@ -985,7 +985,7 @@ func TestPoolerTLS(t *testing.T) {
},
}

var cluster = New(
var cluster = New(context.Background(),
Config{
OpConfig: config.Config{
PodManagementPolicy: "ordered_ready",
Expand Down Expand Up @@ -1066,7 +1066,7 @@ func TestPoolerTLS(t *testing.T) {

func TestConnectionPoolerServiceSpec(t *testing.T) {
testName := "Test connection pooler service spec generation"
var cluster = New(
var cluster = New(context.Background(),
Config{
OpConfig: config.Config{
ProtectedRoles: []string{"admin"},
Expand Down
39 changes: 17 additions & 22 deletions pkg/cluster/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cluster

import (
"bytes"
"context"
"database/sql"
"fmt"
"net"
Expand Down Expand Up @@ -121,26 +122,31 @@ func (c *Cluster) initDbConn() error {
if c.pgDb != nil {
return nil
}

return c.initDbConnWithName("")
}

// Worker function for connection initialization. This function does not check
// if the connection is already open, if it is then it will be overwritten.
// Callers need to make sure no connection is open, otherwise we could leak
// connections
// initDbConnWithName initializes a database connection using the cluster's context.
// This function does not check if the connection is already open.
func (c *Cluster) initDbConnWithName(dbname string) error {
return c.initDbConnWithNameContext(c.ctx, dbname)
}

// initDbConnWithNameContext initializes a database connection with an explicit context.
// Use this when you need a custom context (e.g., different timeout, or context.Background()
// for operations that should not be cancelled). This function does not check if the
// connection is already open, callers need to ensure no connection is open to avoid leaks.
func (c *Cluster) initDbConnWithNameContext(ctx context.Context, dbname string) error {
c.setProcessName("initializing db connection")

var conn *sql.DB
connstring := c.pgConnectionString(dbname)

finalerr := retryutil.Retry(constants.PostgresConnectTimeout, constants.PostgresConnectRetryTimeout,
finalerr := retryutil.RetryWithContext(ctx, constants.PostgresConnectTimeout, constants.PostgresConnectRetryTimeout,
func() (bool, error) {
var err error
conn, err = sql.Open("postgres", connstring)
if err == nil {
err = conn.Ping()
err = conn.PingContext(ctx)
}

if err == nil {
Expand Down Expand Up @@ -268,9 +274,7 @@ func findUsersFromRotation(rotatedUsers []string, db *sql.DB) (map[string]string
}()

for rows.Next() {
var (
rolname, roldatesuffix string
)
var rolname, roldatesuffix string
err := rows.Scan(&rolname, &roldatesuffix)
if err != nil {
return nil, fmt.Errorf("error when processing rows of deprecated users: %v", err)
Expand Down Expand Up @@ -331,9 +335,7 @@ func (c *Cluster) cleanupRotatedUsers(rotatedUsers []string) error {
// getDatabases returns the map of current databases with owners
// The caller is responsible for opening and closing the database connection
func (c *Cluster) getDatabases() (dbs map[string]string, err error) {
var (
rows *sql.Rows
)
var rows *sql.Rows

if rows, err = c.pgDb.Query(getDatabasesSQL); err != nil {
return nil, fmt.Errorf("could not query database: %v", err)
Expand Down Expand Up @@ -551,9 +553,7 @@ func (c *Cluster) getOwnerRoles(dbObjPath string, withUser bool) (owners []strin
// getExtension returns the list of current database extensions
// The caller is responsible for opening and closing the database connection
func (c *Cluster) getExtensions() (dbExtensions map[string]string, err error) {
var (
rows *sql.Rows
)
var rows *sql.Rows

if rows, err = c.pgDb.Query(getExtensionsSQL); err != nil {
return nil, fmt.Errorf("could not query database extensions: %v", err)
Expand Down Expand Up @@ -598,7 +598,6 @@ func (c *Cluster) executeAlterExtension(extName, schemaName string) error {
}

func (c *Cluster) execCreateOrAlterExtension(extName, schemaName, statement, doing, operation string) error {

c.logger.Infof("%s %q schema %q", doing, extName, schemaName)
if _, err := c.pgDb.Exec(fmt.Sprintf(statement, extName, schemaName)); err != nil {
return fmt.Errorf("could not execute %s: %v", operation, err)
Expand All @@ -610,9 +609,7 @@ func (c *Cluster) execCreateOrAlterExtension(extName, schemaName, statement, doi
// getPublications returns the list of current database publications with tables
// The caller is responsible for opening and closing the database connection
func (c *Cluster) getPublications() (publications map[string]string, err error) {
var (
rows *sql.Rows
)
var rows *sql.Rows

if rows, err = c.pgDb.Query(getPublicationsSQL); err != nil {
return nil, fmt.Errorf("could not query database publications: %v", err)
Expand Down Expand Up @@ -668,7 +665,6 @@ func (c *Cluster) executeAlterPublication(pubName, tableList string) error {
}

func (c *Cluster) execCreateOrAlterPublication(pubName, tableList, statement, doing, operation string) error {

c.logger.Debugf("%s %q with table list %q", doing, pubName, tableList)
if _, err := c.pgDb.Exec(fmt.Sprintf(statement, pubName, tableList)); err != nil {
return fmt.Errorf("could not execute %s: %v", operation, err)
Expand Down Expand Up @@ -743,7 +739,6 @@ func (c *Cluster) installLookupFunction(poolerSchema, poolerUser string) error {
constants.PostgresConnectTimeout,
constants.PostgresConnectRetryTimeout,
func() (bool, error) {

// At this moment we are not connected to any database
if err := c.initDbConnWithName(dbname); err != nil {
msg := "could not init database connection to %s"
Expand Down
Loading
Loading