diff --git a/agent/cmd/cmd.go b/agent/cmd/cmd.go index b887d6b45..80b81f206 100644 --- a/agent/cmd/cmd.go +++ b/agent/cmd/cmd.go @@ -14,6 +14,7 @@ package cmd import ( + "context" "flag" "fmt" "net/http" @@ -112,7 +113,10 @@ func WithEffect(f func()) Option { } // Run runs the agent. -func Run(flags *Flags, opts ...Option) { +func Run(ctx context.Context, flags *Flags, opts ...Option) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + validateRequiredPorts(flags) var overrides options @@ -129,11 +133,11 @@ func Run(flags *Flags, opts ...Option) { config = *overrides.config } else { if err := configutil.Load(flags.ConfigFile, &config); err != nil { - panic(err) + return err } if flags.SecretsFile != "" { if err := configutil.Load(flags.SecretsFile, &config); err != nil { - panic(err) + return err } } } @@ -153,7 +157,7 @@ func Run(flags *Flags, opts ...Option) { if stats == nil { s, closer, err := metrics.New(config.Metrics, flags.KrakenCluster) if err != nil { - log.Fatalf("Failed to init metrics: %s", err) + return fmt.Errorf("failed to init metrics: %s", err) } stats = s defer closers.Close(closer) @@ -162,7 +166,7 @@ func Run(flags *Flags, opts ...Option) { if flags.PeerIP == "" { localIP, err := netutil.GetLocalIP() if err != nil { - log.Fatalf("Error getting local ip: %s", err) + return fmt.Errorf("error getting local ip: %s", err) } flags.PeerIP = localIP } @@ -174,40 +178,40 @@ func Run(flags *Flags, opts ...Option) { pctx, err := core.NewPeerContext( config.PeerIDFactory, flags.Zone, flags.KrakenCluster, flags.PeerIP, flags.PeerPort, false) if err != nil { - log.Fatalf("Failed to create peer context: %s", err) + return fmt.Errorf("failed to create peer context: %s", err) } cads, err := store.NewCADownloadStore(config.CADownloadStore, stats) if err != nil { - log.Fatalf("Failed to create local store: %s", err) + return fmt.Errorf("failed to create local store: %s", err) } netevents, err := networkevent.NewProducer(config.NetworkEvent) if err 
!= nil { - log.Fatalf("Failed to create network event producer: %s", err) + return fmt.Errorf("failed to create network event producer: %s", err) } trackers, err := config.Tracker.Build() if err != nil { - log.Fatalf("Error building tracker upstream: %s", err) + return fmt.Errorf("error building tracker upstream: %s", err) } - go trackers.Monitor(nil) + go trackers.Monitor(ctx.Done()) tls, err := config.TLS.BuildClient() if err != nil { - log.Fatalf("Error building client tls config: %s", err) + return fmt.Errorf("error building client tls config: %s", err) } announceClient := announceclient.New(pctx, trackers, tls) sched, err := scheduler.NewAgentScheduler( config.Scheduler, stats, pctx, cads, netevents, trackers, announceClient, tls) if err != nil { - log.Fatalf("Error creating scheduler: %s", err) + return fmt.Errorf("error creating scheduler: %s", err) } buildIndexes, err := config.BuildIndex.Build() if err != nil { - log.Fatalf("Error building build-index upstream: %s", err) + return fmt.Errorf("error building build-index upstream: %s", err) } tagClient := tagclient.NewClusterClient(buildIndexes, tls) @@ -216,7 +220,7 @@ func Run(flags *Flags, opts ...Option) { registry, err := config.Registry.Build(config.Registry.ReadOnlyParameters(transferer, cads, stats)) if err != nil { - log.Fatalf("Failed to init registry: %s", err) + return fmt.Errorf("failed to init registry: %s", err) } registryAddr := fmt.Sprintf("127.0.0.1:%d", flags.AgentRegistryPort) @@ -228,13 +232,14 @@ func Run(flags *Flags, opts ...Option) { } containerRuntimeFactory, err := containerruntime.NewFactory(containerRuntimeCfg, registryAddr) if err != nil { - log.Fatalf("Failed to create container runtime factory: %s", err) + return fmt.Errorf("failed to create container runtime factory: %s", err) } agentServer := agentserver.New( config.AgentServer, stats, cads, sched, tagClient, announceClient, containerRuntimeFactory) addr := fmt.Sprintf(":%d", flags.AgentServerPort) log.Infof("Starting agent 
server on %s", addr) + errCh := make(chan error, 3) heartbeatTicker := &timeTicker{inner: time.NewTicker(10 * time.Second)} heartbeatDone := make(chan struct{}) var heartbeatStop sync.Once @@ -247,31 +252,72 @@ func Run(flags *Flags, opts ...Option) { go heartbeat(stats, heartbeatTicker, heartbeatDone) defer stopHeartbeat() + + httpServer := &http.Server{Addr: addr, Handler: agentServer.Handler()} go func() { - if err := http.ListenAndServe(addr, agentServer.Handler()); err != nil { + defer cancel() + // ErrServerClosed is returned by ListenAndServe when Shutdown() is + // called during a clean shutdown — it is expected, not a real error. + if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { stopHeartbeat() - log.Fatal(err) + log.Errorf("agent server exited: %s", err) + errCh <- err } }() log.Info("Starting registry...") go func() { + defer cancel() if err := registry.ListenAndServe(); err != nil { stopHeartbeat() - log.Fatal(err) + log.Errorf("registry exited: %s", err) + errCh <- err + } + }() + + go func() { + defer cancel() + if err := nginx.RunContext(ctx, config.Nginx, map[string]interface{}{ + "allowed_cidrs": config.AllowedCidrs, + "port": flags.AgentRegistryPort, + "registry_server": nginx.GetServer( + config.Registry.Docker.HTTP.Net, config.Registry.Docker.HTTP.Addr), + "agent_server": fmt.Sprintf("127.0.0.1:%d", flags.AgentServerPort), + "registry_backup": config.RegistryBackup}, + nginx.WithTLS(config.TLS)); err != nil { + stopHeartbeat() + log.Errorf("nginx exited: %s", err) + errCh <- err } }() - if err := nginx.Run(config.Nginx, map[string]interface{}{ - "allowed_cidrs": config.AllowedCidrs, - "port": flags.AgentRegistryPort, - "registry_server": nginx.GetServer( - config.Registry.Docker.HTTP.Net, config.Registry.Docker.HTTP.Addr), - "agent_server": fmt.Sprintf("127.0.0.1:%d", flags.AgentServerPort), - "registry_backup": config.RegistryBackup}, - nginx.WithTLS(config.TLS)); err != nil { - stopHeartbeat() - 
log.Fatal(err) + runErr := waitForShutdown(ctx, errCh) + // Drain in-flight requests before returning. cancel() runs first so the other + // goroutines (registry, nginx) also stop on an internal-error exit; NOTE(review): + // Shutdown(context.Background()) itself waits with no deadline — confirm. + cancel() + if err := httpServer.Shutdown(context.Background()); err != nil { + log.Errorf("agent server shutdown: %s", err) + } + return runErr +} + +// waitForShutdown blocks until ctx is cancelled or an error arrives on errCh. +// Goroutines always send to errCh before calling cancel(), so by the time +// ctx.Done() is observed the error is already buffered and the non-blocking +// drain will always retrieve it. +func waitForShutdown(ctx context.Context, errCh <-chan error) error { + select { + case <-ctx.Done(): + select { + case err := <-errCh: + return err + default: + } + log.Infof("shutting down: %s", ctx.Err()) + return nil + case err := <-errCh: + return err + } } diff --git a/agent/cmd/cmd_test.go b/agent/cmd/cmd_test.go index 9ae7d5c39..51b6cc760 100644 --- a/agent/cmd/cmd_test.go +++ b/agent/cmd/cmd_test.go @@ -1,8 +1,11 @@ package cmd import ( + "context" + "errors" "flag" - "fmt" + "net" + "net/http" "os" "path/filepath" "runtime" @@ -112,7 +115,7 @@ func TestRunValidation(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { assert.PanicsWithValue(t, test.panic, func() { - Run(&test.flags) + _ = Run(context.Background(), &test.flags) //nolint:errcheck }) }) } @@ -131,7 +134,8 @@ func TestRunUsesProvidedConfig(t *testing.T) { called := false assert.PanicsWithValue(t, sentinel, func() { - Run( + _ = Run( //nolint:errcheck + context.Background(), flags, WithConfig(Config{}), WithMetrics(tally.NewTestScope("", nil)), @@ -146,7 +150,7 @@ func TestRunUsesProvidedConfig(t *testing.T) { assert.True(t, called, "effect should be invoked") } -func TestRunPanicsWhenConfigLoadFails(t *testing.T) { +func TestRunReturnsErrorWhenConfigLoadFails(t *testing.T) { missing := filepath.Join(t.TempDir(),
"missing.yaml") flags := &Flags{ @@ -156,18 +160,17 @@ func TestRunPanicsWhenConfigLoadFails(t *testing.T) { ConfigFile: missing, } - expected := fmt.Sprintf("open %s: no such file or directory", missing) - - assert.PanicsWithError(t, expected, func() { - Run( - flags, - WithMetrics(tally.NewTestScope("", nil)), - WithLogger(zap.NewNop()), - ) - }) + err := Run( + context.Background(), + flags, + WithMetrics(tally.NewTestScope("", nil)), + WithLogger(zap.NewNop()), + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "no such file or directory") } -func TestRunPanicsWhenSecretsLoadFails(t *testing.T) { +func TestRunReturnsErrorWhenSecretsLoadFails(t *testing.T) { _, filename, _, ok := runtime.Caller(0) require.True(t, ok) @@ -185,15 +188,14 @@ func TestRunPanicsWhenSecretsLoadFails(t *testing.T) { SecretsFile: missingSecrets, } - expected := fmt.Sprintf("open %s: no such file or directory", missingSecrets) - - assert.PanicsWithError(t, expected, func() { - Run( - flags, - WithMetrics(tally.NewTestScope("", nil)), - WithLogger(zap.NewNop()), - ) - }) + err := Run( + context.Background(), + flags, + WithMetrics(tally.NewTestScope("", nil)), + WithLogger(zap.NewNop()), + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "no such file or directory") } func TestValidateRequiredPorts(t *testing.T) { @@ -288,3 +290,61 @@ func (t clockTicker) Chan() <-chan time.Time { func (t clockTicker) Stop() { t.ticker.Stop() } + +func TestWaitForShutdown_ExternalCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + errCh := make(chan error, 2) + cancel() + err := waitForShutdown(ctx, errCh) + assert.NoError(t, err) +} + +func TestWaitForShutdown_ErrorReceived(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + errCh := make(chan error, 2) + sentinel := errors.New("internal failure") + errCh <- sentinel + err := waitForShutdown(ctx, errCh) + assert.Equal(t, sentinel, err) +} + +func 
TestWaitForShutdown_ErrorWinsWhenBothReady(t *testing.T) { + // When cancel() and errCh <- err happen together (the common internal-error + // path), the error must always be returned — never silently swallowed as nil. + sentinel := errors.New("goroutine failure") + for range 100 { + ctx, cancel := context.WithCancel(context.Background()) + errCh := make(chan error, 2) + cancel() + errCh <- sentinel + err := waitForShutdown(ctx, errCh) + require.Equal(t, sentinel, err, "error must not be swallowed when both ctx and errCh are ready") + } +} + +// TestHTTPServerGracefulShutdown verifies the pattern used in Run: +// ListenAndServe stops cleanly when Shutdown is called, and http.ErrServerClosed +// is not surfaced as a fatal error. +func TestHTTPServerGracefulShutdown(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + srv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})} + + serveErr := make(chan error, 1) + go func() { + serveErr <- srv.Serve(ln) + }() + + // Shutdown should cause Serve to return http.ErrServerClosed — not a real error. + require.NoError(t, srv.Shutdown(context.Background())) + + err = <-serveErr + assert.Equal(t, http.ErrServerClosed, err, "Serve should return ErrServerClosed after Shutdown") + + // Verify the ErrServerClosed guard used in Run filters it correctly. 
+ if err != nil && err != http.ErrServerClosed { + t.Fatal("guard would have incorrectly treated ErrServerClosed as a fatal error") + } +} diff --git a/agent/main.go b/agent/main.go index 636be0831..e44eb88c8 100644 --- a/agent/main.go +++ b/agent/main.go @@ -14,12 +14,23 @@ package main import ( + "context" + "log" + "os" + "os/signal" + "syscall" + "github.com/uber/kraken/agent/cmd" "github.com/uber/kraken/lib/dockerregistry" ) func main() { - cmd.Run(cmd.ParseFlags(), cmd.WithEffect(func() { + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + if err := cmd.Run(ctx, cmd.ParseFlags(), cmd.WithEffect(func() { dockerregistry.RegisterKrakenStorageDriver() - })) + })); err != nil { + log.Fatal(err) + } } diff --git a/nginx/nginx.go b/nginx/nginx.go index cd77d5047..b3377b3ab 100644 --- a/nginx/nginx.go +++ b/nginx/nginx.go @@ -15,12 +15,14 @@ package nginx import ( "bytes" + "context" "errors" "fmt" "os" "os/exec" "path" "path/filepath" + "syscall" "text/template" "github.com/uber/kraken/nginx/config" @@ -168,6 +170,13 @@ func WithTLS(tls httputil.TLSConfig) Option { // Run injects params into an nginx configuration template and runs it. func Run(config Config, params map[string]interface{}, opts ...Option) error { + return RunContext(context.Background(), config, params, opts...) +} + +// RunContext is like Run but sends SIGQUIT to the nginx process when ctx is +// cancelled, triggering nginx's graceful shutdown (drain in-flight connections) +// rather than an immediate kill. +func RunContext(ctx context.Context, config Config, params map[string]interface{}, opts ...Option) error { if err := config.applyDefaults(); err != nil { return fmt.Errorf("invalid config: %s", err) } @@ -240,7 +249,19 @@ func Run(config Config, params map[string]interface{}, opts ...Option) error { cmd := exec.Command(args[0], args[1:]...) 
cmd.Stdout = stdout cmd.Stderr = stdout - return cmd.Run() + if err := cmd.Start(); err != nil { + return fmt.Errorf("start nginx: %s", err) + } + + go func() { + <-ctx.Done() + if cmd.Process != nil { + if err := cmd.Process.Signal(syscall.SIGQUIT); err != nil { + log.Errorf("nginx SIGQUIT: %s", err) + } + } + }() + return cmd.Wait() } func populateTemplate(tmpl string, args map[string]interface{}) ([]byte, error) {