diff --git a/go.mod b/go.mod index 6c17bf894..11dc65380 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/trimble-oss/go-webauthn-client v0.3.0 golang.org/x/net v0.30.0 + golang.org/x/sys v0.28.0 gopkg.in/ini.v1 v1.67.0 ) @@ -57,7 +58,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.31.0 // indirect golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect - golang.org/x/sys v0.28.0 // indirect golang.org/x/term v0.27.0 // indirect golang.org/x/text v0.21.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/pkg/awsconfig/awsconfig.go b/pkg/awsconfig/awsconfig.go index b1bfef9dc..43d28c081 100644 --- a/pkg/awsconfig/awsconfig.go +++ b/pkg/awsconfig/awsconfig.go @@ -1,6 +1,7 @@ package awsconfig import ( + "io/fs" "os" "path" "path/filepath" @@ -73,6 +74,15 @@ func (p *CredentialsProvider) Save(awsCreds *AWSCredentials) error { return err } + lock, err := newFileLock(filename) + if err != nil { + return errors.Wrap(err, "unable to create credentials file lock") + } + if err := lock.Lock(); err != nil { + return errors.Wrap(err, "unable to acquire credentials file lock") + } + defer lock.Unlock() + err = p.ensureConfigExists() if err != nil { if os.IsNotExist(err) { @@ -246,5 +256,35 @@ func saveProfile(filename, profile string, awsCreds *AWSCredentials) error { return err } - return config.SaveTo(filename) + // Atomic write: write to temp file then rename to prevent partial writes + tmpFile, err := os.CreateTemp(filepath.Dir(filename), ".credentials.tmp.*") + if err != nil { + return errors.Wrap(err, "unable to create temp file") + } + tmpName := tmpFile.Name() + + _, err = config.WriteTo(tmpFile) + if err != nil { + tmpFile.Close() + os.Remove(tmpName) + return errors.Wrap(err, "unable to write credentials to temp file") + } + if err := tmpFile.Close(); err != nil { + os.Remove(tmpName) + return errors.Wrap(err, "unable to close temp file") + } + + // Preserve permissions of the original file + if fi, err := os.Stat(filename); err == nil { + _ = os.Chmod(tmpName, fi.Mode()&fs.ModePerm) + } else { + _ = os.Chmod(tmpName, 0600) + } + + if err := os.Rename(tmpName, filename); err != nil { + os.Remove(tmpName) + return errors.Wrap(err, "unable to rename temp credentials file") + } + + return nil } diff --git a/pkg/awsconfig/awsconfig_test.go b/pkg/awsconfig/awsconfig_test.go index 7d8c5f8b5..5bc7b5e10 100644 --- a/pkg/awsconfig/awsconfig_test.go +++ b/pkg/awsconfig/awsconfig_test.go @@ -1,12 +1,15 @@ package awsconfig import ( + "fmt" "os" + "strings" + "sync" "testing" "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" + ini "gopkg.in/ini.v1" ) func TestUpdateSamlConfig(t *testing.T) { @@ -37,4 +40,75 @@ func TestUpdateSamlConfig(t *testing.T) { assert.Equal(t, "testtoken", awsCreds.AWSSessionToken) os.Remove(".credentials") + os.Remove(lockFilePath(".credentials")) +} + +func TestConcurrentSave(t *testing.T) { + const filename = ".credentials_concurrent" + const numWorkers = 10 + + os.Remove(filename) + os.Remove(lockFilePath(filename)) + defer os.Remove(filename) + defer os.Remove(lockFilePath(filename)) + + // Create the initial file so ensureConfigExists doesn't race + err := os.WriteFile(filename, []byte("[default]"), 0600) + assert.NoError(t, err) + + var wg sync.WaitGroup + errCh := make(chan error, numWorkers) + + for i := 0; i < numWorkers; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + profile := fmt.Sprintf("profile-%d", idx) + creds := &CredentialsProvider{Filename: filename, Profile: profile} + awsCreds := &AWSCredentials{ + AWSAccessKey: fmt.Sprintf("key-%d", idx), + AWSSecretKey: fmt.Sprintf("secret-%d", idx), + AWSSessionToken: fmt.Sprintf("token-%d", idx), + AWSSecurityToken: fmt.Sprintf("sectoken-%d", idx), + } + if err := creds.Save(awsCreds); err != nil { + errCh <- err + } + }(i) + } + wg.Wait() + close(errCh) + + for err := range errCh { + t.Fatalf("concurrent save error: %v", err) + } + + // Verify no duplicate sections + config, err := ini.Load(filename) + assert.NoError(t, err) + + sectionNames := config.SectionStrings() + seen := make(map[string]bool) + for _, name := range sectionNames { + if seen[name] { + t.Errorf("duplicate section found: %s", name) + } + seen[name] = true + } + + // Verify all profiles are present + for i := 0; i < numWorkers; i++ { + profile := fmt.Sprintf("profile-%d", i) + assert.True(t, seen[profile], "missing profile: %s", profile) + } + + // Verify file content has no duplicate section headers via raw text + data, err := os.ReadFile(filename) + assert.NoError(t, err) + content := string(data) + for i := 0; i < numWorkers; i++ { + header := fmt.Sprintf("[profile-%d]", i) + count := strings.Count(content, header) + assert.Equal(t, 1, count, "expected exactly 1 occurrence of %s, got %d", header, count) + } } diff --git a/pkg/awsconfig/filelock.go b/pkg/awsconfig/filelock.go new file mode 100644 index 000000000..f4ee85d73 --- /dev/null +++ b/pkg/awsconfig/filelock.go @@ -0,0 +1,12 @@ +package awsconfig + +// fileLock provides mutual exclusion for file access across processes. +type fileLock interface { + Lock() error + Unlock() error +} + +// lockFilePath returns the lock file path for a given file. +func lockFilePath(path string) string { + return path + ".lck" +} diff --git a/pkg/awsconfig/filelock_test.go b/pkg/awsconfig/filelock_test.go new file mode 100644 index 000000000..a25dac6de --- /dev/null +++ b/pkg/awsconfig/filelock_test.go @@ -0,0 +1,61 @@ +//go:build !windows +// +build !windows + +package awsconfig + +import ( + "os" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestFileLockExclusivity(t *testing.T) { + tmpFile, err := os.CreateTemp("", "locktest") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + tmpFile.Close() + + lock1, err := newFileLock(tmpFile.Name()) + assert.NoError(t, err) + + err = lock1.Lock() + assert.NoError(t, err) + + // Try to acquire a second lock in a goroutine — it should block + var mu sync.Mutex + acquired := false + + go func() { + lock2, err := newFileLock(tmpFile.Name()) + if err != nil { + return + } + _ = lock2.Lock() + mu.Lock() + acquired = true + mu.Unlock() + _ = lock2.Unlock() + }() + + // Give the goroutine time to attempt the lock + time.Sleep(100 * time.Millisecond) + mu.Lock() + assert.False(t, acquired, "second lock should be blocked while first is held") + mu.Unlock() + + // Release first lock + err = lock1.Unlock() + assert.NoError(t, err) + + // Wait for second lock to acquire + time.Sleep(100 * time.Millisecond) + mu.Lock() + assert.True(t, acquired, "second lock should acquire after first is released") + mu.Unlock() + + // Clean up lock file + os.Remove(lockFilePath(tmpFile.Name())) +} diff --git a/pkg/awsconfig/filelock_unix.go b/pkg/awsconfig/filelock_unix.go new file mode 100644 index 000000000..d223f9b69 --- /dev/null +++ b/pkg/awsconfig/filelock_unix.go @@ -0,0 +1,31 @@ +//go:build !windows +// +build !windows + +package awsconfig + +import ( + "os" + + "golang.org/x/sys/unix" +) + +type unixFileLock struct { + f *os.File +} + +func newFileLock(path string) (fileLock, error) { + f, err := os.OpenFile(lockFilePath(path), os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + return nil, err + } + return &unixFileLock{f: f}, nil +} + +func (l *unixFileLock) Lock() error { + return unix.Flock(int(l.f.Fd()), unix.LOCK_EX) +} + +func (l *unixFileLock) Unlock() error { + defer l.f.Close() + return unix.Flock(int(l.f.Fd()), unix.LOCK_UN) +} diff --git a/pkg/awsconfig/filelock_windows.go b/pkg/awsconfig/filelock_windows.go new file mode 100644 index 000000000..6bea517b6 --- /dev/null +++ b/pkg/awsconfig/filelock_windows.go @@ -0,0 +1,30 @@ +package awsconfig + +import ( + "os" + + "golang.org/x/sys/windows" +) + +type windowsFileLock struct { + f *os.File +} + +func newFileLock(path string) (fileLock, error) { + f, err := os.OpenFile(lockFilePath(path), os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + return nil, err + } + return &windowsFileLock{f: f}, nil +} + +func (l *windowsFileLock) Lock() error { + ol := new(windows.Overlapped) + return windows.LockFileEx(windows.Handle(l.f.Fd()), windows.LOCKFILE_EXCLUSIVE_LOCK, 0, 1, 0, ol) +} + +func (l *windowsFileLock) Unlock() error { + ol := new(windows.Overlapped) + defer l.f.Close() + return windows.UnlockFileEx(windows.Handle(l.f.Fd()), 0, 1, 0, ol) +}