Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ require (
github.com/tidwall/gjson v1.18.0
github.com/trimble-oss/go-webauthn-client v0.3.0
golang.org/x/net v0.30.0
golang.org/x/sys v0.28.0
gopkg.in/ini.v1 v1.67.0
)

Expand Down Expand Up @@ -57,7 +58,6 @@ require (
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/term v0.27.0 // indirect
golang.org/x/text v0.21.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
42 changes: 41 additions & 1 deletion pkg/awsconfig/awsconfig.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package awsconfig

import (
"io/fs"
"os"
"path"
"path/filepath"
Expand Down Expand Up @@ -73,6 +74,15 @@ func (p *CredentialsProvider) Save(awsCreds *AWSCredentials) error {
return err
}

lock, err := newFileLock(filename)
if err != nil {
return errors.Wrap(err, "unable to create credentials file lock")
}
if err := lock.Lock(); err != nil {
return errors.Wrap(err, "unable to acquire credentials file lock")
}
defer lock.Unlock()

err = p.ensureConfigExists()
if err != nil {
if os.IsNotExist(err) {
Expand Down Expand Up @@ -246,5 +256,35 @@ func saveProfile(filename, profile string, awsCreds *AWSCredentials) error {
return err
}

return config.SaveTo(filename)
// Atomic write: write to temp file then rename to prevent partial writes
tmpFile, err := os.CreateTemp(filepath.Dir(filename), ".credentials.tmp.*")
if err != nil {
return errors.Wrap(err, "unable to create temp file")
}
tmpName := tmpFile.Name()

_, err = config.WriteTo(tmpFile)
if err != nil {
tmpFile.Close()
os.Remove(tmpName)
return errors.Wrap(err, "unable to write credentials to temp file")
}
if err := tmpFile.Close(); err != nil {
os.Remove(tmpName)
return errors.Wrap(err, "unable to close temp file")
}

// Preserve permissions of the original file
if fi, err := os.Stat(filename); err == nil {
_ = os.Chmod(tmpName, fi.Mode()&fs.ModePerm)
} else {
_ = os.Chmod(tmpName, 0600)
}

if err := os.Rename(tmpName, filename); err != nil {
os.Remove(tmpName)
return errors.Wrap(err, "unable to rename temp credentials file")
}

return nil
}
76 changes: 75 additions & 1 deletion pkg/awsconfig/awsconfig_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package awsconfig

import (
"fmt"
"os"
"strings"
"sync"
"testing"

"github.com/sirupsen/logrus"

"github.com/stretchr/testify/assert"
ini "gopkg.in/ini.v1"
)

func TestUpdateSamlConfig(t *testing.T) {
Expand Down Expand Up @@ -37,4 +40,75 @@ func TestUpdateSamlConfig(t *testing.T) {
assert.Equal(t, "testtoken", awsCreds.AWSSessionToken)

os.Remove(".credentials")
os.Remove(lockFilePath(".credentials"))
}

func TestConcurrentSave(t *testing.T) {
const filename = ".credentials_concurrent"
const numWorkers = 10

os.Remove(filename)
os.Remove(lockFilePath(filename))
defer os.Remove(filename)
defer os.Remove(lockFilePath(filename))

// Create the initial file so ensureConfigExists doesn't race
err := os.WriteFile(filename, []byte("[default]"), 0600)
assert.NoError(t, err)

var wg sync.WaitGroup
errCh := make(chan error, numWorkers)

for i := 0; i < numWorkers; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
profile := fmt.Sprintf("profile-%d", idx)
creds := &CredentialsProvider{Filename: filename, Profile: profile}
awsCreds := &AWSCredentials{
AWSAccessKey: fmt.Sprintf("key-%d", idx),
AWSSecretKey: fmt.Sprintf("secret-%d", idx),
AWSSessionToken: fmt.Sprintf("token-%d", idx),
AWSSecurityToken: fmt.Sprintf("sectoken-%d", idx),
}
if err := creds.Save(awsCreds); err != nil {
errCh <- err
}
}(i)
}
wg.Wait()
close(errCh)

for err := range errCh {
t.Fatalf("concurrent save error: %v", err)
}

// Verify no duplicate sections
config, err := ini.Load(filename)
assert.NoError(t, err)

sectionNames := config.SectionStrings()
seen := make(map[string]bool)
for _, name := range sectionNames {
if seen[name] {
t.Errorf("duplicate section found: %s", name)
}
seen[name] = true
}

// Verify all profiles are present
for i := 0; i < numWorkers; i++ {
profile := fmt.Sprintf("profile-%d", i)
assert.True(t, seen[profile], "missing profile: %s", profile)
}

// Verify file content has no duplicate section headers via raw text
data, err := os.ReadFile(filename)
assert.NoError(t, err)
content := string(data)
for i := 0; i < numWorkers; i++ {
header := fmt.Sprintf("[profile-%d]", i)
count := strings.Count(content, header)
assert.Equal(t, 1, count, "expected exactly 1 occurrence of %s, got %d", header, count)
}
}
12 changes: 12 additions & 0 deletions pkg/awsconfig/filelock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package awsconfig

// fileLock provides mutual exclusion for file access across processes.
type fileLock interface {
Lock() error
Unlock() error
}

// lockFilePath returns the lock file path for a given file.
func lockFilePath(path string) string {
return path + ".lck"
}
61 changes: 61 additions & 0 deletions pkg/awsconfig/filelock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//go:build !windows
// +build !windows

package awsconfig

import (
"os"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestFileLockExclusivity(t *testing.T) {
tmpFile, err := os.CreateTemp("", "locktest")
assert.NoError(t, err)
defer os.Remove(tmpFile.Name())
tmpFile.Close()

lock1, err := newFileLock(tmpFile.Name())
assert.NoError(t, err)

err = lock1.Lock()
assert.NoError(t, err)

// Try to acquire a second lock in a goroutine — it should block
var mu sync.Mutex
acquired := false

go func() {
lock2, err := newFileLock(tmpFile.Name())
if err != nil {
return
}
_ = lock2.Lock()
mu.Lock()
acquired = true
mu.Unlock()
_ = lock2.Unlock()
}()

// Give the goroutine time to attempt the lock
time.Sleep(100 * time.Millisecond)
mu.Lock()
assert.False(t, acquired, "second lock should be blocked while first is held")
mu.Unlock()

// Release first lock
err = lock1.Unlock()
assert.NoError(t, err)

// Wait for second lock to acquire
time.Sleep(100 * time.Millisecond)
mu.Lock()
assert.True(t, acquired, "second lock should acquire after first is released")
mu.Unlock()

// Clean up lock file
os.Remove(lockFilePath(tmpFile.Name()))
}
31 changes: 31 additions & 0 deletions pkg/awsconfig/filelock_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//go:build !windows
// +build !windows

package awsconfig

import (
"os"

"golang.org/x/sys/unix"
)

type unixFileLock struct {
f *os.File
}

func newFileLock(path string) (fileLock, error) {
f, err := os.OpenFile(lockFilePath(path), os.O_CREATE|os.O_RDWR, 0600)
if err != nil {
return nil, err
}
return &unixFileLock{f: f}, nil
}

func (l *unixFileLock) Lock() error {
return unix.Flock(int(l.f.Fd()), unix.LOCK_EX)
}

func (l *unixFileLock) Unlock() error {
defer l.f.Close()
return unix.Flock(int(l.f.Fd()), unix.LOCK_UN)
}
30 changes: 30 additions & 0 deletions pkg/awsconfig/filelock_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package awsconfig

import (
"os"

"golang.org/x/sys/windows"
)

type windowsFileLock struct {
f *os.File
}

func newFileLock(path string) (fileLock, error) {
f, err := os.OpenFile(lockFilePath(path), os.O_CREATE|os.O_RDWR, 0600)
if err != nil {
return nil, err
}
return &windowsFileLock{f: f}, nil
}

func (l *windowsFileLock) Lock() error {
ol := new(windows.Overlapped)
return windows.LockFileEx(windows.Handle(l.f.Fd()), windows.LOCKFILE_EXCLUSIVE_LOCK, 0, 1, 0, ol)
}

func (l *windowsFileLock) Unlock() error {
ol := new(windows.Overlapped)
defer l.f.Close()
return windows.UnlockFileEx(windows.Handle(l.f.Fd()), 0, 1, 0, ol)
}