From 45e68af4e7bf7d4314f7716393fe14e92f40b407 Mon Sep 17 00:00:00 2001 From: Jay Janssen Date: Thu, 5 Mar 2026 13:47:05 -0500 Subject: [PATCH] fix(awsconfig): add file locking to prevent duplicate credential sections Implement exclusive file locking and atomic writes for the AWS credentials file to prevent duplicate section headers when multiple saml2aws processes run concurrently. Use flock on Unix/Linux/macOS and LockFileEx on Windows via golang.org/x/sys. Add TestConcurrentSave to verify no duplicates under load. Fixes concurrent race condition in Save() and saveProfile(). Co-Authored-By: Claude Opus 4.6 Co-authored-by: Claude Code Ai-assisted: true --- go.mod | 2 +- pkg/awsconfig/awsconfig.go | 42 ++++++++++++++++- pkg/awsconfig/awsconfig_test.go | 76 ++++++++++++++++++++++++++++++- pkg/awsconfig/filelock.go | 12 +++++ pkg/awsconfig/filelock_test.go | 61 +++++++++++++++++++++++++ pkg/awsconfig/filelock_unix.go | 31 +++++++++++++ pkg/awsconfig/filelock_windows.go | 30 ++++++++++++ 7 files changed, 251 insertions(+), 3 deletions(-) create mode 100644 pkg/awsconfig/filelock.go create mode 100644 pkg/awsconfig/filelock_test.go create mode 100644 pkg/awsconfig/filelock_unix.go create mode 100644 pkg/awsconfig/filelock_windows.go diff --git a/go.mod b/go.mod index 6c17bf894..11dc65380 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/trimble-oss/go-webauthn-client v0.3.0 golang.org/x/net v0.30.0 + golang.org/x/sys v0.28.0 gopkg.in/ini.v1 v1.67.0 ) @@ -57,7 +58,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.31.0 // indirect golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect - golang.org/x/sys v0.28.0 // indirect golang.org/x/term v0.27.0 // indirect golang.org/x/text v0.21.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/pkg/awsconfig/awsconfig.go b/pkg/awsconfig/awsconfig.go index b1bfef9dc..43d28c081 100644 --- a/pkg/awsconfig/awsconfig.go +++ b/pkg/awsconfig/awsconfig.go @@ -1,6 +1,7 @@ package awsconfig import ( + "io/fs" "os" "path" "path/filepath" @@ -73,6 +74,15 @@ func (p *CredentialsProvider) Save(awsCreds *AWSCredentials) error { return err } + lock, err := newFileLock(filename) + if err != nil { + return errors.Wrap(err, "unable to create credentials file lock") + } + if err := lock.Lock(); err != nil { + return errors.Wrap(err, "unable to acquire credentials file lock") + } + defer lock.Unlock() + err = p.ensureConfigExists() if err != nil { if os.IsNotExist(err) { @@ -246,5 +256,35 @@ func saveProfile(filename, profile string, awsCreds *AWSCredentials) error { return err } - return config.SaveTo(filename) + // Atomic write: write to temp file then rename to prevent partial writes + tmpFile, err := os.CreateTemp(filepath.Dir(filename), ".credentials.tmp.*") + if err != nil { + return errors.Wrap(err, "unable to create temp file") + } + tmpName := tmpFile.Name() + + _, err = config.WriteTo(tmpFile) + if err != nil { + tmpFile.Close() + os.Remove(tmpName) + return errors.Wrap(err, "unable to write credentials to temp file") + } + if err := tmpFile.Close(); err != nil { + os.Remove(tmpName) + return errors.Wrap(err, "unable to close temp file") + } + + // Preserve permissions of the original file + if fi, err := os.Stat(filename); err == nil { + _ = os.Chmod(tmpName, fi.Mode()&fs.ModePerm) + } else { + _ = os.Chmod(tmpName, 0600) + } + + if err := os.Rename(tmpName, filename); err != nil { + os.Remove(tmpName) + return errors.Wrap(err, "unable to rename temp credentials file") + } + + return nil } diff --git a/pkg/awsconfig/awsconfig_test.go b/pkg/awsconfig/awsconfig_test.go index 7d8c5f8b5..5bc7b5e10 100644 --- a/pkg/awsconfig/awsconfig_test.go +++ b/pkg/awsconfig/awsconfig_test.go @@ -1,12 +1,15 @@ package awsconfig import ( + "fmt" "os" + "strings" + "sync" "testing" "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" + ini "gopkg.in/ini.v1" ) func TestUpdateSamlConfig(t *testing.T) { @@ -37,4 +40,75 @@ func TestUpdateSamlConfig(t *testing.T) { assert.Equal(t, "testtoken", awsCreds.AWSSessionToken) os.Remove(".credentials") + os.Remove(lockFilePath(".credentials")) +} + +func TestConcurrentSave(t *testing.T) { + const filename = ".credentials_concurrent" + const numWorkers = 10 + + os.Remove(filename) + os.Remove(lockFilePath(filename)) + defer os.Remove(filename) + defer os.Remove(lockFilePath(filename)) + + // Create the initial file so ensureConfigExists doesn't race + err := os.WriteFile(filename, []byte("[default]"), 0600) + assert.NoError(t, err) + + var wg sync.WaitGroup + errCh := make(chan error, numWorkers) + + for i := 0; i < numWorkers; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + profile := fmt.Sprintf("profile-%d", idx) + creds := &CredentialsProvider{Filename: filename, Profile: profile} + awsCreds := &AWSCredentials{ + AWSAccessKey: fmt.Sprintf("key-%d", idx), + AWSSecretKey: fmt.Sprintf("secret-%d", idx), + AWSSessionToken: fmt.Sprintf("token-%d", idx), + AWSSecurityToken: fmt.Sprintf("sectoken-%d", idx), + } + if err := creds.Save(awsCreds); err != nil { + errCh <- err + } + }(i) + } + wg.Wait() + close(errCh) + + for err := range errCh { + t.Fatalf("concurrent save error: %v", err) + } + + // Verify no duplicate sections + config, err := ini.Load(filename) + assert.NoError(t, err) + + sectionNames := config.SectionStrings() + seen := make(map[string]bool) + for _, name := range sectionNames { + if seen[name] { + t.Errorf("duplicate section found: %s", name) + } + seen[name] = true + } + + // Verify all profiles are present + for i := 0; i < numWorkers; i++ { + profile := fmt.Sprintf("profile-%d", i) + assert.True(t, seen[profile], "missing profile: %s", profile) + } + + // Verify file content has no duplicate section headers via raw text + data, err := os.ReadFile(filename) + assert.NoError(t, err) + content := string(data) + for i := 0; i < numWorkers; i++ { + header := fmt.Sprintf("[profile-%d]", i) + count := strings.Count(content, header) + assert.Equal(t, 1, count, "expected exactly 1 occurrence of %s, got %d", header, count) + } } diff --git a/pkg/awsconfig/filelock.go b/pkg/awsconfig/filelock.go new file mode 100644 index 000000000..f4ee85d73 --- /dev/null +++ b/pkg/awsconfig/filelock.go @@ -0,0 +1,12 @@ +package awsconfig + +// fileLock provides mutual exclusion for file access across processes. +type fileLock interface { + Lock() error + Unlock() error +} + +// lockFilePath returns the lock file path for a given file. +func lockFilePath(path string) string { + return path + ".lck" +} diff --git a/pkg/awsconfig/filelock_test.go b/pkg/awsconfig/filelock_test.go new file mode 100644 index 000000000..a25dac6de --- /dev/null +++ b/pkg/awsconfig/filelock_test.go @@ -0,0 +1,61 @@ +//go:build !windows +// +build !windows + +package awsconfig + +import ( + "os" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestFileLockExclusivity(t *testing.T) { + tmpFile, err := os.CreateTemp("", "locktest") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + tmpFile.Close() + + lock1, err := newFileLock(tmpFile.Name()) + assert.NoError(t, err) + + err = lock1.Lock() + assert.NoError(t, err) + + // Try to acquire a second lock in a goroutine — it should block + var mu sync.Mutex + acquired := false + + go func() { + lock2, err := newFileLock(tmpFile.Name()) + if err != nil { + return + } + _ = lock2.Lock() + mu.Lock() + acquired = true + mu.Unlock() + _ = lock2.Unlock() + }() + + // Give the goroutine time to attempt the lock + time.Sleep(100 * time.Millisecond) + mu.Lock() + assert.False(t, acquired, "second lock should be blocked while first is held") + mu.Unlock() + + // Release first lock + err = lock1.Unlock() + assert.NoError(t, err) + + // Wait for second lock to acquire + time.Sleep(100 * time.Millisecond) + mu.Lock() + assert.True(t, acquired, "second lock should acquire after first is released") + mu.Unlock() + + // Clean up lock file + os.Remove(lockFilePath(tmpFile.Name())) +} diff --git a/pkg/awsconfig/filelock_unix.go b/pkg/awsconfig/filelock_unix.go new file mode 100644 index 000000000..d223f9b69 --- /dev/null +++ b/pkg/awsconfig/filelock_unix.go @@ -0,0 +1,31 @@ +//go:build !windows +// +build !windows + +package awsconfig + +import ( + "os" + + "golang.org/x/sys/unix" +) + +type unixFileLock struct { + f *os.File +} + +func newFileLock(path string) (fileLock, error) { + f, err := os.OpenFile(lockFilePath(path), os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + return nil, err + } + return &unixFileLock{f: f}, nil +} + +func (l *unixFileLock) Lock() error { + return unix.Flock(int(l.f.Fd()), unix.LOCK_EX) +} + +func (l *unixFileLock) Unlock() error { + defer l.f.Close() + return unix.Flock(int(l.f.Fd()), unix.LOCK_UN) +} diff --git a/pkg/awsconfig/filelock_windows.go b/pkg/awsconfig/filelock_windows.go new file mode 100644 index 000000000..6bea517b6 --- /dev/null +++ b/pkg/awsconfig/filelock_windows.go @@ -0,0 +1,30 @@ +package awsconfig + +import ( + "os" + + "golang.org/x/sys/windows" +) + +type windowsFileLock struct { + f *os.File +} + +func newFileLock(path string) (fileLock, error) { + f, err := os.OpenFile(lockFilePath(path), os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + return nil, err + } + return &windowsFileLock{f: f}, nil +} + +func (l *windowsFileLock) Lock() error { + ol := new(windows.Overlapped) + return windows.LockFileEx(windows.Handle(l.f.Fd()), windows.LOCKFILE_EXCLUSIVE_LOCK, 0, 1, 0, ol) +} + +func (l *windowsFileLock) Unlock() error { + ol := new(windows.Overlapped) + defer l.f.Close() + return windows.UnlockFileEx(windows.Handle(l.f.Fd()), 0, 1, 0, ol) +}