Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 74 additions & 28 deletions agent/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package cmd

import (
"context"
"flag"
"fmt"
"net/http"
Expand Down Expand Up @@ -112,7 +113,10 @@ func WithEffect(f func()) Option {
}

// Run runs the agent.
func Run(flags *Flags, opts ...Option) {
func Run(ctx context.Context, flags *Flags, opts ...Option) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

validateRequiredPorts(flags)

var overrides options
Expand All @@ -129,11 +133,11 @@ func Run(flags *Flags, opts ...Option) {
config = *overrides.config
} else {
if err := configutil.Load(flags.ConfigFile, &config); err != nil {
panic(err)
return err
}
if flags.SecretsFile != "" {
if err := configutil.Load(flags.SecretsFile, &config); err != nil {
panic(err)
return err
}
}
}
Expand All @@ -153,7 +157,7 @@ func Run(flags *Flags, opts ...Option) {
if stats == nil {
s, closer, err := metrics.New(config.Metrics, flags.KrakenCluster)
if err != nil {
log.Fatalf("Failed to init metrics: %s", err)
return fmt.Errorf("failed to init metrics: %s", err)
}
stats = s
defer closers.Close(closer)
Expand All @@ -162,7 +166,7 @@ func Run(flags *Flags, opts ...Option) {
if flags.PeerIP == "" {
localIP, err := netutil.GetLocalIP()
if err != nil {
log.Fatalf("Error getting local ip: %s", err)
return fmt.Errorf("error getting local ip: %s", err)
}
flags.PeerIP = localIP
}
Expand All @@ -174,40 +178,40 @@ func Run(flags *Flags, opts ...Option) {
pctx, err := core.NewPeerContext(
config.PeerIDFactory, flags.Zone, flags.KrakenCluster, flags.PeerIP, flags.PeerPort, false)
if err != nil {
log.Fatalf("Failed to create peer context: %s", err)
return fmt.Errorf("failed to create peer context: %s", err)
}

cads, err := store.NewCADownloadStore(config.CADownloadStore, stats)
if err != nil {
log.Fatalf("Failed to create local store: %s", err)
return fmt.Errorf("failed to create local store: %s", err)
}

netevents, err := networkevent.NewProducer(config.NetworkEvent)
if err != nil {
log.Fatalf("Failed to create network event producer: %s", err)
return fmt.Errorf("failed to create network event producer: %s", err)
}

trackers, err := config.Tracker.Build()
if err != nil {
log.Fatalf("Error building tracker upstream: %s", err)
return fmt.Errorf("error building tracker upstream: %s", err)
}
go trackers.Monitor(nil)
go trackers.Monitor(ctx.Done())

tls, err := config.TLS.BuildClient()
if err != nil {
log.Fatalf("Error building client tls config: %s", err)
return fmt.Errorf("error building client tls config: %s", err)
}

announceClient := announceclient.New(pctx, trackers, tls)
sched, err := scheduler.NewAgentScheduler(
config.Scheduler, stats, pctx, cads, netevents, trackers, announceClient, tls)
if err != nil {
log.Fatalf("Error creating scheduler: %s", err)
return fmt.Errorf("error creating scheduler: %s", err)
}

buildIndexes, err := config.BuildIndex.Build()
if err != nil {
log.Fatalf("Error building build-index upstream: %s", err)
return fmt.Errorf("error building build-index upstream: %s", err)
}

tagClient := tagclient.NewClusterClient(buildIndexes, tls)
Expand All @@ -216,7 +220,7 @@ func Run(flags *Flags, opts ...Option) {

registry, err := config.Registry.Build(config.Registry.ReadOnlyParameters(transferer, cads, stats))
if err != nil {
log.Fatalf("Failed to init registry: %s", err)
return fmt.Errorf("failed to init registry: %s", err)
}

registryAddr := fmt.Sprintf("127.0.0.1:%d", flags.AgentRegistryPort)
Expand All @@ -228,13 +232,14 @@ func Run(flags *Flags, opts ...Option) {
}
containerRuntimeFactory, err := containerruntime.NewFactory(containerRuntimeCfg, registryAddr)
if err != nil {
log.Fatalf("Failed to create container runtime factory: %s", err)
return fmt.Errorf("failed to create container runtime factory: %s", err)
}

agentServer := agentserver.New(
config.AgentServer, stats, cads, sched, tagClient, announceClient, containerRuntimeFactory)
addr := fmt.Sprintf(":%d", flags.AgentServerPort)
log.Infof("Starting agent server on %s", addr)
errCh := make(chan error, 3)
heartbeatTicker := &timeTicker{inner: time.NewTicker(10 * time.Second)}
heartbeatDone := make(chan struct{})
var heartbeatStop sync.Once
Expand All @@ -247,31 +252,72 @@ func Run(flags *Flags, opts ...Option) {

go heartbeat(stats, heartbeatTicker, heartbeatDone)
defer stopHeartbeat()

httpServer := &http.Server{Addr: addr, Handler: agentServer.Handler()}
go func() {
if err := http.ListenAndServe(addr, agentServer.Handler()); err != nil {
defer cancel()
// ErrServerClosed is returned by ListenAndServe when Shutdown() is
// called during a clean shutdown — it is expected, not a real error.
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
Comment on lines 257 to +261
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should cancel on all errors, right?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sambhav-jain-16

From https://pkg.go.dev/net/http#Server.Shutdown

" Shutdown does not attempt to close nor wait for hijacked connections such as WebSockets. [...] When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS immediately return ErrServerClosed. Make sure the program doesn't exit and waits instead for Shutdown to return."

http.ErrServerClosed is returned by ListenAndServe exactly when Shutdown() or Close() is called on the server. Here's the flow in this code:

  • OS sends SIGTERM/SIGINT → signal.NotifyContext in main.go cancels the top-level ctx
  • ctx.Done() fires → the shutdown goroutine calls httpServer.Shutdown(context.Background())
  • Shutdown() drains in-flight requests, then causes ListenAndServe to return http.ErrServerClosed

stopHeartbeat()
log.Fatal(err)
log.Errorf("agent server exited: %s", err)
errCh <- err
}
}()

log.Info("Starting registry...")
go func() {
defer cancel()
if err := registry.ListenAndServe(); err != nil {
stopHeartbeat()
log.Fatal(err)
log.Errorf("registry exited: %s", err)
errCh <- err
}
}()

go func() {
defer cancel()
if err := nginx.RunContext(ctx, config.Nginx, map[string]interface{}{
"allowed_cidrs": config.AllowedCidrs,
"port": flags.AgentRegistryPort,
"registry_server": nginx.GetServer(
Comment on lines +278 to +283
Comment on lines +278 to +283
config.Registry.Docker.HTTP.Net, config.Registry.Docker.HTTP.Addr),
"agent_server": fmt.Sprintf("127.0.0.1:%d", flags.AgentServerPort),
"registry_backup": config.RegistryBackup},
nginx.WithTLS(config.TLS)); err != nil {
stopHeartbeat()
log.Errorf("nginx exited: %s", err)
errCh <- err
}
}()

if err := nginx.Run(config.Nginx, map[string]interface{}{
"allowed_cidrs": config.AllowedCidrs,
"port": flags.AgentRegistryPort,
"registry_server": nginx.GetServer(
config.Registry.Docker.HTTP.Net, config.Registry.Docker.HTTP.Addr),
"agent_server": fmt.Sprintf("127.0.0.1:%d", flags.AgentServerPort),
"registry_backup": config.RegistryBackup},
nginx.WithTLS(config.TLS)); err != nil {
stopHeartbeat()
log.Fatal(err)
runErr := waitForShutdown(ctx, errCh)
// Drain in-flight HTTP requests before returning. cancel() is called first
// so Shutdown does not block indefinitely if we are exiting due to an error
// rather than a signal.
cancel()
if err := httpServer.Shutdown(context.Background()); err != nil {
log.Errorf("agent server shutdown: %s", err)
}
return runErr
}

// waitForShutdown blocks until either ctx is cancelled or a goroutine error
// arrives on errCh, and returns the error (nil on a clean, signal-driven
// shutdown). The worker goroutines send to errCh before their deferred
// cancel() runs, so any internal failure is already buffered by the time
// ctx.Done() is observed — the non-blocking drain below will find it.
func waitForShutdown(ctx context.Context, errCh <-chan error) error {
	// Block until one of the two shutdown triggers fires.
	select {
	case err := <-errCh:
		return err
	case <-ctx.Done():
	}

	// ctx was cancelled: prefer an already-buffered goroutine error over nil
	// so internal failures are never silently swallowed by the race.
	select {
	case err := <-errCh:
		return err
	default:
		log.Infof("shutting down: %s", ctx.Err())
		return nil
	}
}

Expand Down
106 changes: 83 additions & 23 deletions agent/cmd/cmd_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package cmd

import (
"context"
"errors"
"flag"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"runtime"
Expand Down Expand Up @@ -112,7 +115,7 @@ func TestRunValidation(t *testing.T) {
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
assert.PanicsWithValue(t, test.panic, func() {
Run(&test.flags)
_ = Run(context.Background(), &test.flags) //nolint:errcheck
})
})
}
Expand All @@ -131,7 +134,8 @@ func TestRunUsesProvidedConfig(t *testing.T) {
called := false

assert.PanicsWithValue(t, sentinel, func() {
Run(
_ = Run( //nolint:errcheck
context.Background(),
flags,
WithConfig(Config{}),
WithMetrics(tally.NewTestScope("", nil)),
Expand All @@ -146,7 +150,7 @@ func TestRunUsesProvidedConfig(t *testing.T) {
assert.True(t, called, "effect should be invoked")
}

func TestRunPanicsWhenConfigLoadFails(t *testing.T) {
func TestRunReturnsErrorWhenConfigLoadFails(t *testing.T) {
missing := filepath.Join(t.TempDir(), "missing.yaml")

flags := &Flags{
Expand All @@ -156,18 +160,17 @@ func TestRunPanicsWhenConfigLoadFails(t *testing.T) {
ConfigFile: missing,
}

expected := fmt.Sprintf("open %s: no such file or directory", missing)

assert.PanicsWithError(t, expected, func() {
Run(
flags,
WithMetrics(tally.NewTestScope("", nil)),
WithLogger(zap.NewNop()),
)
})
err := Run(
context.Background(),
flags,
WithMetrics(tally.NewTestScope("", nil)),
WithLogger(zap.NewNop()),
)
require.Error(t, err)
assert.Contains(t, err.Error(), "no such file or directory")
}

func TestRunPanicsWhenSecretsLoadFails(t *testing.T) {
func TestRunReturnsErrorWhenSecretsLoadFails(t *testing.T) {
_, filename, _, ok := runtime.Caller(0)
require.True(t, ok)

Expand All @@ -185,15 +188,14 @@ func TestRunPanicsWhenSecretsLoadFails(t *testing.T) {
SecretsFile: missingSecrets,
}

expected := fmt.Sprintf("open %s: no such file or directory", missingSecrets)

assert.PanicsWithError(t, expected, func() {
Run(
flags,
WithMetrics(tally.NewTestScope("", nil)),
WithLogger(zap.NewNop()),
)
})
err := Run(
context.Background(),
flags,
WithMetrics(tally.NewTestScope("", nil)),
WithLogger(zap.NewNop()),
)
require.Error(t, err)
assert.Contains(t, err.Error(), "no such file or directory")
}

func TestValidateRequiredPorts(t *testing.T) {
Expand Down Expand Up @@ -288,3 +290,61 @@ func (t clockTicker) Chan() <-chan time.Time {
// Stop delegates to the wrapped ticker's Stop, releasing its resources.
// NOTE(review): presumably t.ticker is a *time.Ticker (the Chan method
// returns <-chan time.Time) — stopping it does not close the channel.
func (t clockTicker) Stop() {
	t.ticker.Stop()
}

func TestWaitForShutdown_ExternalCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
errCh := make(chan error, 2)
cancel()
err := waitForShutdown(ctx, errCh)
assert.NoError(t, err)
}

func TestWaitForShutdown_ErrorReceived(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errCh := make(chan error, 2)
sentinel := errors.New("internal failure")
errCh <- sentinel
err := waitForShutdown(ctx, errCh)
assert.Equal(t, sentinel, err)
}

func TestWaitForShutdown_ErrorWinsWhenBothReady(t *testing.T) {
// When cancel() and errCh <- err happen together (the common internal-error
// path), the error must always be returned — never silently swallowed as nil.
sentinel := errors.New("goroutine failure")
for range 100 {
ctx, cancel := context.WithCancel(context.Background())
errCh := make(chan error, 2)
cancel()
errCh <- sentinel
err := waitForShutdown(ctx, errCh)
require.Equal(t, sentinel, err, "error must not be swallowed when both ctx and errCh are ready")
}
}

// TestHTTPServerGracefulShutdown exercises the server pattern used in Run:
// a graceful Shutdown must stop Serve cleanly, and the resulting
// http.ErrServerClosed must not be treated as a real failure.
func TestHTTPServerGracefulShutdown(t *testing.T) {
	listener, listenErr := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, listenErr)

	noop := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
	server := &http.Server{Handler: noop}

	done := make(chan error, 1)
	go func() { done <- server.Serve(listener) }()

	// Graceful shutdown itself must succeed...
	require.NoError(t, server.Shutdown(context.Background()))

	// ...and cause Serve to return the expected sentinel error.
	serveResult := <-done
	assert.Equal(t, http.ErrServerClosed, serveResult, "Serve should return ErrServerClosed after Shutdown")

	// Mirror the guard used in Run: ErrServerClosed is filtered out.
	if serveResult != nil && serveResult != http.ErrServerClosed {
		t.Fatal("guard would have incorrectly treated ErrServerClosed as a fatal error")
	}
}
15 changes: 13 additions & 2 deletions agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,23 @@
package main

import (
"context"
"log"
"os"
"os/signal"
"syscall"

"github.com/uber/kraken/agent/cmd"
"github.com/uber/kraken/lib/dockerregistry"
)

func main() {
cmd.Run(cmd.ParseFlags(), cmd.WithEffect(func() {
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't change the exit code mechanism.

Copy link
Copy Markdown
Author

@dajneem23 dajneem23 May 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't change the exit code mechanism.

Agreed — and we didn't. The exit code mechanism is unchanged from master. log.Fatal(err) is still the last line in main, so any fatal error still terminates with exit code 1 via os.Exit(1) exactly as before.
cmd.Run needs a context. Context that is cancelled on SIGINT/SIGTERM.

cmd.Run now takes ctx—the context is passed down so internal goroutines (HTTP server, nginx) can shut down gracefully when a signal arrives, instead of being killed abruptly.

Is this acceptable?


if err := cmd.Run(ctx, cmd.ParseFlags(), cmd.WithEffect(func() {
dockerregistry.RegisterKrakenStorageDriver()
}))
})); err != nil {
log.Fatal(err)
}
}
Loading