diff --git a/cmd/internal/flags/flags_test.go b/cmd/internal/flags/flags_test.go index 9b3430e..50b9ed1 100644 --- a/cmd/internal/flags/flags_test.go +++ b/cmd/internal/flags/flags_test.go @@ -21,8 +21,65 @@ import ( "github.com/spf13/cobra" "github.com/spf13/viper" + + "github.com/go-gremlins/gremlins/internal/configuration" ) +type cliCase struct { + want any + getResult func(string) any + flag Flag + args []string +} + +func TestSetBindsTypedValueFromCLI(t *testing.T) { + testCases := []cliCase{ + { + flag: Flag{ + Name: "threshold-efficacy", + CfgKey: "unleash.threshold.efficacy", + DefaultV: float64(0), + Usage: "test usage", + }, + args: []string{"--threshold-efficacy", "50"}, + getResult: func(k string) any { return configuration.Get[float64](k) }, + want: float64(50), + }, + { + flag: Flag{ + Name: "workers", + CfgKey: "unleash.workers", + DefaultV: 0, + Usage: "test usage", + }, + args: []string{"--workers", "4"}, + getResult: func(k string) any { return configuration.Get[int](k) }, + want: 4, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.flag.Name, func(t *testing.T) { + defer viper.Reset() + + cmd := &cobra.Command{Use: "test", Run: func(_ *cobra.Command, _ []string) {}} + // #nosec G601 - We are in tests, we don't care + if err := Set(cmd, &tc.flag); err != nil { + t.Fatal("Set should not fail") + } + cmd.SetArgs(tc.args) + if err := cmd.Execute(); err != nil { + t.Fatal("Execute should not fail") + } + + if got := tc.getResult(tc.flag.CfgKey); got != tc.want { + t.Errorf("expected configuration.Get(%q) to be %T(%v), got %T(%v)", tc.flag.CfgKey, tc.want, tc.want, got, got) + } + }) + } +} + type unsupportedType int type testCase struct { diff --git a/cmd/unleash_test.go b/cmd/unleash_test.go index 0e7b637..6cfda59 100644 --- a/cmd/unleash_test.go +++ b/cmd/unleash_test.go @@ -22,6 +22,9 @@ import ( "strings" "testing" + "github.com/spf13/cobra" + "github.com/spf13/viper" + "github.com/go-gremlins/gremlins/internal/configuration" "github.com/go-gremlins/gremlins/internal/mutator" ) @@ -201,3 +204,48 @@ func TestUnleash(t *testing.T) { } } } + +func TestUnleashFlagsPropagateToConfiguration(t *testing.T) { + c, err := newUnleashCmd(context.Background()) + if err != nil { + t.Fatal("newUnleashCmd should not fail") + } + c.cmd.RunE = func(_ *cobra.Command, _ []string) error { return nil } + c.cmd.SetArgs([]string{"--threshold-efficacy", "50", "--threshold-mcover", "25", "--workers", "4"}) + if err := c.cmd.Execute(); err != nil { + t.Fatal("Execute should not fail") + } + + testCases := []struct { + got any + want any + key string + }{ + { + key: configuration.UnleashThresholdEfficacyKey, + got: configuration.Get[float64](configuration.UnleashThresholdEfficacyKey), + want: float64(50), + }, + { + key: configuration.UnleashThresholdMCoverageKey, + got: configuration.Get[float64](configuration.UnleashThresholdMCoverageKey), + want: float64(25), + }, + { + key: configuration.UnleashWorkersKey, + got: configuration.Get[int](configuration.UnleashWorkersKey), + want: 4, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.key, func(t *testing.T) { + if tc.got != tc.want { + t.Errorf("expected %q to be %v, got %v", tc.key, tc.want, tc.got) + } + }) + } + + viper.Reset() +} diff --git a/go.mod b/go.mod index 69a8373..9b255e3 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/hako/durafmt v0.0.0-20210608085754-5c1018a4e16b github.com/hectane/go-acl v0.0.0-20230122075934-ca0b05cb1adb github.com/mitchellh/go-homedir v1.1.0 + github.com/spf13/cast v1.10.0 github.com/spf13/cobra v1.10.2 github.com/spf13/pflag v1.0.10 github.com/spf13/viper v1.21.0 @@ -26,7 +27,6 @@ require ( github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/sagikazarmark/locafero v0.12.0 // indirect github.com/spf13/afero v1.15.0 // indirect - github.com/spf13/cast v1.10.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/sys v0.42.0 // indirect diff --git a/internal/configuration/configuration.go b/internal/configuration/configuration.go index cf13433..a06c77a 100644 --- a/internal/configuration/configuration.go +++ b/internal/configuration/configuration.go @@ -26,6 +26,7 @@ import ( "sync" "github.com/mitchellh/go-homedir" + "github.com/spf13/cast" "github.com/spf13/viper" "github.com/go-gremlins/gremlins/internal/mutator" @@ -180,11 +181,37 @@ func Set[T any](k string, v T) { } // Get offers synchronised access to Viper. +// +// Values bound from pflag may surface through viper.Get as their string +// representation rather than the native Go type, so a plain type assertion +// would silently yield the zero value. Coerce via spf13/cast for the +// numeric and string types Gremlins uses, falling back to the assertion +// for anything else (slices, custom types). func Get[T any](k string) T { - var r T mutex.RLock() defer mutex.RUnlock() - r, _ = viper.Get(k).(T) + + v := viper.Get(k) + var zero T + if v == nil { + return zero + } + + var coerced any + switch any(zero).(type) { + case float64: + coerced = cast.ToFloat64(v) + case int: + coerced = cast.ToInt(v) + case bool: + coerced = cast.ToBool(v) + case string: + coerced = cast.ToString(v) + default: + coerced = v + } + + r, _ := coerced.(T) return r } diff --git a/internal/configuration/configuration_test.go b/internal/configuration/configuration_test.go index 9b7feaf..12c3c3c 100644 --- a/internal/configuration/configuration_test.go +++ b/internal/configuration/configuration_test.go @@ -24,11 +24,67 @@ import ( "github.com/google/go-cmp/cmp" "github.com/mitchellh/go-homedir" + "github.com/spf13/pflag" "github.com/spf13/viper" "github.com/go-gremlins/gremlins/internal/mutator" ) +func TestGetFromBoundPFlag(t *testing.T) { + testCases := []struct { + register func(fs *pflag.FlagSet, name string) + check func(t *testing.T, cfgKey string) + name string + flagName string + cfgKey string + args []string + }{ + { + name: "float64", + flagName: "efficacy", + cfgKey: "unleash.threshold.efficacy", + args: []string{"--efficacy", "50"}, + register: func(fs *pflag.FlagSet, name string) { fs.Float64(name, 0, "") }, + check: func(t *testing.T, cfgKey string) { + t.Helper() + if got := Get[float64](cfgKey); got != 50.0 { + t.Errorf("Get[float64](%q) = %v, want 50.0", cfgKey, got) + } + }, + }, + { + name: "int", + flagName: "workers", + cfgKey: "unleash.workers", + args: []string{"--workers", "4"}, + register: func(fs *pflag.FlagSet, name string) { fs.Int(name, 0, "") }, + check: func(t *testing.T, cfgKey string) { + t.Helper() + if got := Get[int](cfgKey); got != 4 { + t.Errorf("Get[int](%q) = %v, want 4", cfgKey, got) + } + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + fs := pflag.NewFlagSet("test", pflag.ContinueOnError) + tc.register(fs, tc.flagName) + if err := viper.BindPFlag(tc.cfgKey, fs.Lookup(tc.flagName)); err != nil { + t.Fatal("BindPFlag should not fail") + } + if err := fs.Parse(tc.args); err != nil { + t.Fatal("Parse should not fail") + } + + tc.check(t, tc.cfgKey) + viper.Reset() + }) + } +} + type envEntry struct { name string value string @@ -149,12 +205,9 @@ func TestConfigPaths(t *testing.T) { filepath.Join(home, ".gremlins"), ) - // Then module root + // Then module root, then current folder moduleRoot, _ := os.Getwd() - want = append(want, moduleRoot) - - // Last current folder - want = append(want, ".") + want = append(want, moduleRoot, ".") got := defaultConfigPaths() @@ -177,15 +230,13 @@ func TestConfigPaths(t *testing.T) { want = append(want, "/etc/gremlins") } - // Then $XDG_CONFIG_HOME and $HOME + // Then $XDG_CONFIG_HOME and $HOME, then current folder want = append(want, filepath.Join(home, ".config", "gremlins", "gremlins"), filepath.Join(home, ".gremlins"), + ".", ) - // Last current folder - want = append(want, ".") - got := defaultConfigPaths() if !cmp.Equal(got, want) { @@ -215,12 +266,9 @@ func TestConfigPaths(t *testing.T) { filepath.Join(customPath, "gremlins", "gremlins"), filepath.Join(home, ".gremlins")) - // Then Go module root + // Then Go module root, then current directory moduleRoot, _ := os.Getwd() - want = append(want, moduleRoot) - - // Last the current directory - want = append(want, ".") + want = append(want, moduleRoot, ".") got := defaultConfigPaths()