From 3d325709c8d03da72681f55e0d2d29e7a5b09a6d Mon Sep 17 00:00:00 2001 From: Michael Weber Date: Wed, 15 Apr 2026 14:05:19 -0500 Subject: [PATCH 01/23] Fix synced config application Unmarshal the full merged config in MergeConfiguration and apply that end state directly instead of routing a sparse diff back through the sync apply path. Keep applyPatch as the helper for real partial section diffs, and factor shared normalization/store helpers so both flows continue to validate and normalize definitions before persisting them. --- internal/config/sync.go | 175 +++++++++++++++++++++++++++++------ internal/config/sync_test.go | 157 ++++++++++++++++++++++++++++++- 2 files changed, 305 insertions(+), 27 deletions(-) diff --git a/internal/config/sync.go b/internal/config/sync.go index fc6abdd1..fa706fd0 100644 --- a/internal/config/sync.go +++ b/internal/config/sync.go @@ -58,26 +58,21 @@ func (c *Config) MergeConfiguration(config *RegistrationResponse) error { return err } - // Create a patch to see the differences between existing and new - incomingPatch, err := jsonpatch.CreateMergePatch(existingData, newData) + // Convert the merged configuration back to a struct so the in-memory + // config reflects the fully merged remote+local state rather than a + // sparse diff payload. Unmarshaling the sparse merge patch here would + // collapse omitted fields to zero values in typed structs. + var mergedConfig ConfigPatchRequest + err = json.Unmarshal(newData, &mergedConfig) if err != nil { - logrus.WithError(err).Errorln("Failed to create merge patch for configuration diffing") - return err - } - - // Convert patches back to structs - these are the NEW changes from the remote - // server that we need to apply to our existing configuration - var incomingDiff ConfigPatchRequest - err = json.Unmarshal(incomingPatch, &incomingDiff) - - if err != nil { - logrus.WithError(err).Errorln("Failed to unmarshal incoming patch") + logrus.WithError(err).Errorln("Failed to unmarshal merged configuration") return err } - // Add these new changes to our existing configuration - err = c.applyPatch(incomingDiff) + // Apply the desired end state directly. applyPatch remains the helper for + // callers that actually have partial section diffs to merge locally first. + err = c.applyMergedConfig(mergedConfig) if err != nil { logrus.WithError(err).Errorln("Failed to apply incoming configuration patch") @@ -139,6 +134,8 @@ func (c *Config) MergeConfiguration(config *RegistrationResponse) error { } func (c *Config) applyPatch(diff ConfigPatchRequest) error { + // applyPatch is the partial-patch helper: merge the incoming section diff + // with the current live section, then normalize and persist the result. // Apply role changes if diff.RoleConfig != nil { err := c.updateRoles(diff.RoleConfig) @@ -169,23 +166,149 @@ func (c *Config) applyPatch(diff ConfigPatchRequest) error { return nil } +// applyMergedConfig applies a fully merged server state. It normalizes each +// definitions map and stores it directly without re-merging the same sections. +func (c *Config) applyMergedConfig(config ConfigPatchRequest) error { + if config.RoleConfig != nil { + if err := c.storeRoleDefinitions(config.RoleConfig.Definitions); err != nil { + logrus.WithError(err).Errorln("Failed to apply merged role configuration") + return err + } + } + + if config.WorkflowConfig != nil { + if err := c.storeWorkflowDefinitions(config.WorkflowConfig.Definitions); err != nil { + logrus.WithError(err).Errorln("Failed to apply merged workflow configuration") + return err + } + } + + if config.ProviderConfig != nil { + if err := c.storeProviderDefinitions(config.ProviderConfig.Definitions); err != nil { + logrus.WithError(err).Errorln("Failed to apply merged provider configuration") + return err + } + } + + return nil +} + +func mergeConfigSection(current any, incoming any, out any) error { + currentData, err := json.Marshal(current) + if err != nil { + return err + } + + incomingData, err := json.Marshal(incoming) + if err != nil { + return err + } + + mergedData, err := jsonpatch.MergePatch(currentData, incomingData) + if err != nil { + return err + } + + return json.Unmarshal(mergedData, out) +} + func (c *Config) updateRoles(roleConfig *RoleConfig) error { - _, err := c.ApplyRoles([]*models.RoleDefinitions{{ - Roles: roleConfig.Definitions, - }}) - return err + c.mu.RLock() + current := RoleConfig{ + Path: c.Roles.Path, + URL: c.Roles.URL, + Vault: c.Roles.Vault, + Definitions: c.Roles.Definitions, + } + c.mu.RUnlock() + + var merged RoleConfig + if err := mergeConfigSection(current, *roleConfig, &merged); err != nil { + return err + } + + return c.storeRoleDefinitions(merged.Definitions) } func (c *Config) updateWorkflows(workflowConfig *WorkflowConfig) error { - _, err := c.ApplyWorkflows([]*models.WorkflowDefinitions{{ - Workflows: workflowConfig.Definitions, - }}) - return err + c.mu.RLock() + current := WorkflowConfig{ + Path: c.Workflows.Path, + URL: c.Workflows.URL, + Vault: c.Workflows.Vault, + Plugins: c.Workflows.Plugins, + Definitions: c.Workflows.Definitions, + } + c.mu.RUnlock() + + var merged WorkflowConfig + if err := mergeConfigSection(current, *workflowConfig, &merged); err != nil { + return err + } + + return c.storeWorkflowDefinitions(merged.Definitions) } func (c *Config) updateProviders(providerConfig *ProviderDefinitionsConfig) error { - _, err := c.ApplyProviders([]*models.ProviderDefinitions{{ - Providers: providerConfig.Definitions, + c.mu.RLock() + current := ProviderDefinitionsConfig{ + Path: c.Providers.Path, + URL: c.Providers.URL, + Vault: c.Providers.Vault, + Plugins: c.Providers.Plugins, + Definitions: c.Providers.Definitions, + } + c.mu.RUnlock() + + var merged ProviderDefinitionsConfig + if err := mergeConfigSection(current, *providerConfig, &merged); err != nil { + return err + } + + return c.storeProviderDefinitions(merged.Definitions) +} + +func (c *Config) storeRoleDefinitions(definitions map[string]models.Role) error { + defs, err := c.ApplyRoles([]*models.RoleDefinitions{{ + Roles: definitions, }}) - return err + if err != nil { + return err + } + + c.mu.Lock() + c.Roles.Definitions = defs + c.mu.Unlock() + + return nil +} + +func (c *Config) storeWorkflowDefinitions(definitions map[string]models.Workflow) error { + defs, err := c.ApplyWorkflows([]*models.WorkflowDefinitions{{ + Workflows: definitions, + }}) + if err != nil { + return err + } + + c.mu.Lock() + c.Workflows.Definitions = defs + c.mu.Unlock() + + return nil +} + +func (c *Config) storeProviderDefinitions(definitions map[string]models.ProviderConfig) error { + defs, err := c.ApplyProviders([]*models.ProviderDefinitions{{ + Providers: definitions, + }}) + if err != nil { + return err + } + + c.mu.Lock() + c.Providers.Definitions = defs + c.mu.Unlock() + + return nil } diff --git a/internal/config/sync_test.go b/internal/config/sync_test.go index b40e88e1..21bd992d 100644 --- a/internal/config/sync_test.go +++ b/internal/config/sync_test.go @@ -10,6 +10,7 @@ import ( "time" jsonpatch "github.com/evanphx/json-patch" + "github.com/serverlessworkflow/sdk-go/v3/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thand-io/agent/internal/models" @@ -116,6 +117,26 @@ func makeRegistrationResponse( return resp } +func makeTestWorkflow(name, description string) models.Workflow { + return models.Workflow{ + Name: name, + Description: description, + Enabled: true, + Workflow: &model.Workflow{ + Do: &model.TaskList{}, + }, + } +} + +func makeTestProvider(name, description string) models.ProviderConfig { + return models.ProviderConfig{ + Name: name, + Description: description, + Provider: "mock", + Enabled: true, + } +} + // waitForPatch waits for a single PATCH call on the channel, or times out. func waitForPatch(ch <-chan syncPatchCall, timeout time.Duration) (syncPatchCall, bool) { select { @@ -147,6 +168,11 @@ func TestMergeConfiguration_ServerSendsNewRoles(t *testing.T) { err := config.MergeConfiguration(reg) require.NoError(t, err) + role, exists := config.Roles.Definitions["admin"] + require.True(t, exists, "expected synced role to be stored locally") + assert.Equal(t, "admin", role.Name) + assert.True(t, role.Enabled) + _, ok := waitForPatch(patchCh, 5*time.Second) require.True(t, ok, "expected outgoing PATCH call") } @@ -157,7 +183,8 @@ func TestMergeConfiguration_ServerSendsUpdatedRole(t *testing.T) { // Local config has an existing role config := newSyncTestConfig(t, map[string]models.Role{ - "editor": {Name: "editor", Description: "Can edit", Enabled: true}, + "editor": {Name: "editor", Description: "Can edit", Enabled: true}, + "untouched": {Name: "untouched", Description: "Keep me", Enabled: true}, }, nil, nil, server.URL, ) @@ -173,6 +200,128 @@ func TestMergeConfiguration_ServerSendsUpdatedRole(t *testing.T) { err := config.MergeConfiguration(reg) require.NoError(t, err) + role, exists := config.Roles.Definitions["editor"] + require.True(t, exists, "expected synced role to be stored locally") + assert.Equal(t, "Can edit and publish", role.Description) + assert.True(t, role.Enabled) + assert.Contains(t, config.Roles.Definitions, "untouched") + + _, ok := waitForPatch(patchCh, 5*time.Second) + require.True(t, ok, "expected outgoing PATCH call") +} + +func TestMergeConfiguration_ServerSendsNewWorkflows(t *testing.T) { + server, patchCh := newSyncTestServer(t) + + config := newSyncTestConfig(t, nil, nil, nil, server.URL) + + reg := makeRegistrationResponse( + nil, + map[string]models.Workflow{ + "approval": makeTestWorkflow("approval", "Handles approvals"), + }, + nil, + ) + + err := config.MergeConfiguration(reg) + require.NoError(t, err) + + workflow, exists := config.Workflows.Definitions["approval"] + require.True(t, exists, "expected synced workflow to be stored locally") + assert.Equal(t, "Handles approvals", workflow.Description) + require.NotNil(t, workflow.Workflow) + + _, ok := waitForPatch(patchCh, 5*time.Second) + require.True(t, ok, "expected outgoing PATCH call") +} + +func TestMergeConfiguration_ServerSendsUpdatedWorkflow(t *testing.T) { + server, patchCh := newSyncTestServer(t) + + config := newSyncTestConfig(t, + nil, + map[string]models.Workflow{ + "existing": makeTestWorkflow("existing", "Existing workflow"), + "unchanged": makeTestWorkflow("unchanged", "Keep me"), + }, + nil, + server.URL, + ) + + reg := makeRegistrationResponse( + nil, + map[string]models.Workflow{ + "existing": makeTestWorkflow("existing", "Updated workflow"), + }, + nil, + ) + + err := config.MergeConfiguration(reg) + require.NoError(t, err) + + workflow, exists := config.Workflows.Definitions["existing"] + require.True(t, exists, "expected synced workflow to be stored locally") + assert.Equal(t, "Updated workflow", workflow.Description) + assert.Contains(t, config.Workflows.Definitions, "unchanged") + + _, ok := waitForPatch(patchCh, 5*time.Second) + require.True(t, ok, "expected outgoing PATCH call") +} + +func TestMergeConfiguration_ServerSendsNewProviders(t *testing.T) { + server, patchCh := newSyncTestServer(t) + + config := newSyncTestConfig(t, nil, nil, nil, server.URL) + + reg := makeRegistrationResponse( + nil, + nil, + map[string]models.ProviderConfig{ + "mock-primary": makeTestProvider("mock-primary", "Primary mock provider"), + }, + ) + + err := config.MergeConfiguration(reg) + require.NoError(t, err) + + provider, exists := config.Providers.Definitions["mock-primary"] + require.True(t, exists, "expected synced provider to be stored locally") + assert.Equal(t, "Primary mock provider", provider.Description) + assert.Equal(t, "mock", provider.Provider) + + _, ok := waitForPatch(patchCh, 5*time.Second) + require.True(t, ok, "expected outgoing PATCH call") +} + +func TestMergeConfiguration_ServerSendsUpdatedProvider(t *testing.T) { + server, patchCh := newSyncTestServer(t) + + config := newSyncTestConfig(t, + nil, + nil, + map[string]models.ProviderConfig{ + "mock-primary": makeTestProvider("mock-primary", "Old provider description"), + "mock-extra": makeTestProvider("mock-extra", "Keep me"), + }, + server.URL, + ) + + reg := makeRegistrationResponse( + nil, + nil, + map[string]models.ProviderConfig{ + "mock-primary": makeTestProvider("mock-primary", "Updated provider description"), + }, + ) + + err := config.MergeConfiguration(reg) + require.NoError(t, err) + + provider, exists := config.Providers.Definitions["mock-primary"] + require.True(t, exists, "expected synced provider to be stored locally") + assert.Equal(t, "Updated provider description", provider.Description) + assert.Contains(t, config.Providers.Definitions, "mock-extra") + _, ok := waitForPatch(patchCh, 5*time.Second) require.True(t, ok, "expected outgoing PATCH call") } @@ -218,6 +367,11 @@ func TestMergeConfiguration_PartialConfig_OnlyRoles(t *testing.T) { err := config.MergeConfiguration(reg) require.NoError(t, err) + role, exists := config.Roles.Definitions["new-role"] + require.True(t, exists, "expected synced role to be stored locally") + assert.Equal(t, "new-role", role.Name) + assert.Contains(t, config.Roles.Definitions, "existing") + _, ok := waitForPatch(patchCh, 5*time.Second) require.True(t, ok, "expected outgoing PATCH call") } @@ -461,6 +615,7 @@ func TestApplyPatch_AppliesRoles(t *testing.T) { err := config.applyPatch(diff) assert.NoError(t, err) + assert.Contains(t, config.Roles.Definitions, "new-role") } func TestApplyPatch_SkipsNilWorkflows(t *testing.T) { From 6621fa2cf4228ee0c90f8080a7b9550773d43e95 Mon Sep 17 00:00:00 2001 From: Michael Weber Date: Wed, 15 Apr 2026 14:13:38 -0500 Subject: [PATCH 02/23] Retry config sync on concurrent changes Snapshot the current config generation before building the merged sync view, normalize the merged role/workflow/provider definition maps off-lock, and only commit them if the generation is unchanged. Keep the retry logic scoped to MergeConfiguration, compare and commit definitions only, and detach the snapshot through JSON so stale retries do not alias nested state. Reloaded definitions now bump the generation counter, while broader nested-mutation cleanup remains tracked in #306. --- docs/development/index.md | 9 + internal/config/config.go | 3 + internal/config/model.go | 6 + internal/config/sync.go | 457 ++++++++++++++++++++++++++++------- internal/config/sync_test.go | 276 ++++++++++++++++++++- 5 files changed, 669 insertions(+), 82 deletions(-) diff --git a/docs/development/index.md b/docs/development/index.md index a5f846b7..24736c4d 100644 --- a/docs/development/index.md +++ b/docs/development/index.md @@ -9,3 +9,12 @@ description: Developer documentation for Thand Agent # Development Documentation for developers contributing to or extending the Thand Agent. + +## Config Mutation Invariant + +Configuration definition maps should be treated as immutable snapshots. +When config changes, prefer replacing whole entries or whole definition maps +instead of mutating nested state in place. Some older code paths still perform +mutation-prone updates; keep new code aligned with the invariant and track +cleanup of legacy exceptions in follow-up issues rather than extending them. +Current cleanup work is tracked in [#306](https://github.com/thand-io/agent/issues/306). diff --git a/internal/config/config.go b/internal/config/config.go index 3acef337..b39b4984 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -358,6 +358,7 @@ func (c *Config) ReloadConfig() error { logrus.Infoln("Loaded workflows from external source:", len(workflows)) c.mu.Lock() c.Workflows.Definitions = workflows + c.configGeneration++ c.mu.Unlock() } else { logrus.Warningln("No workflows loaded from external source") @@ -376,6 +377,7 @@ func (c *Config) ReloadConfig() error { logrus.Infoln("Loaded providers from external source:", len(providers)) c.mu.Lock() c.Providers.Definitions = providers + c.configGeneration++ c.mu.Unlock() } else { logrus.Warningln("No providers loaded from external source") @@ -392,6 +394,7 @@ func (c *Config) ReloadConfig() error { logrus.Infoln("Loaded roles from external source:", len(roles)) c.mu.Lock() c.Roles.Definitions = roles + c.configGeneration++ c.mu.Unlock() } else { logrus.Warningln("No roles loaded from external source") diff --git a/internal/config/model.go b/internal/config/model.go index 3d820544..4d3f9b46 100644 --- a/internal/config/model.go +++ b/internal/config/model.go @@ -64,6 +64,12 @@ type Config struct { logger thandLogger mu sync.RWMutex + // Incremented whenever synced config definitions actually change. + // Definition maps should be treated as immutable snapshots: callers should + // replace whole entries or whole maps rather than mutating nested state + // in place. Legacy mutation-prone paths are being tracked in issue #306. + configGeneration uint64 + // Cached services client initializeServiceClientOnce sync.Once servicesClient models.ServicesClientImpl diff --git a/internal/config/sync.go b/internal/config/sync.go index fa706fd0..51f8971b 100644 --- a/internal/config/sync.go +++ b/internal/config/sync.go @@ -1,6 +1,7 @@ package config import ( + "bytes" "encoding/json" "fmt" "net/http" @@ -12,6 +13,34 @@ import ( "github.com/thand-io/agent/internal/models" ) +const maxMergeConfigurationRetries = 3 + +type configPatchSnapshot struct { + generation uint64 + request ConfigPatchRequest + data []byte + roleDefinitionsJSON []byte + workflowDefinitionsJSON []byte + providerDefinitionsJSON []byte +} + +type buildMergedConfigResult struct { + config ConfigPatchRequest + outgoingPatch []byte +} + +type normalizedDefinitionsPatch struct { + roleDefinitions map[string]models.Role + roleDefinitionsJSON []byte + roleDefinitionsChanged bool + workflowDefinitions map[string]models.Workflow + workflowDefinitionsJSON []byte + workflowDefinitionsChanged bool + providerDefinitions map[string]models.ProviderConfig + providerDefinitionsJSON []byte + providerDefinitionsChanged bool +} + type ConfigPatchRequest struct { RoleConfig *RoleConfig `json:"roles,omitempty"` WorkflowConfig *WorkflowConfig `json:"workflows,omitempty"` @@ -33,64 +62,42 @@ func (c *Config) MergeConfiguration(config *RegistrationResponse) error { return err } - roles := c.GetRolesConfig() - workflows := c.GetWorkflowsConfig() - providers := c.GetProvidersConfig() - - existing := ConfigPatchRequest{ - RoleConfig: roles, - WorkflowConfig: workflows, - ProviderConfig: providers, - } - - existingData, err := json.Marshal(existing) - - if err != nil { - logrus.WithError(err).Errorln("Failed to marshal existing configuration for diffing") - return err - } - - // Apply the incoming changes over the existing configurations - newData, err := jsonpatch.MergePatch(existingData, incomingData) - - if err != nil { - logrus.WithError(err).Errorln("Failed to create merge patch for configuration diffing") - return err - } - - // Convert the merged configuration back to a struct so the in-memory - // config reflects the fully merged remote+local state rather than a - // sparse diff payload. Unmarshaling the sparse merge patch here would - // collapse omitted fields to zero values in typed structs. - var mergedConfig ConfigPatchRequest - err = json.Unmarshal(newData, &mergedConfig) - - if err != nil { - logrus.WithError(err).Errorln("Failed to unmarshal merged configuration") - return err - } - - // Apply the desired end state directly. applyPatch remains the helper for - // callers that actually have partial section diffs to merge locally first. - err = c.applyMergedConfig(mergedConfig) - - if err != nil { - logrus.WithError(err).Errorln("Failed to apply incoming configuration patch") - return err - } + outgoingPatch, err := c.applyMergedConfigWithRetries(func(snapshot *configPatchSnapshot) (*buildMergedConfigResult, error) { + // Apply the incoming changes over the existing configurations. + newData, err := jsonpatch.MergePatch(snapshot.data, incomingData) + if err != nil { + logrus.WithError(err).Errorln("Failed to create merge patch for configuration diffing") + return nil, err + } - // Now we need to figure out what changes exist on the local system that need to - // be sent back to the server + // Convert the merged configuration back to a struct so the in-memory + // config reflects the fully merged remote+local state rather than a + // sparse diff payload. Unmarshaling the sparse merge patch here would + // collapse omitted fields to zero values in typed structs. + var mergedConfig ConfigPatchRequest + err = json.Unmarshal(newData, &mergedConfig) + if err != nil { + logrus.WithError(err).Errorln("Failed to unmarshal merged configuration") + return nil, err + } - outgoingPatch, err := jsonpatch.CreateMergePatch(incomingData, existingData) + // Now we need to figure out what changes exist on the local system that need to + // be sent back to the server + outgoingPatch, err := jsonpatch.CreateMergePatch(incomingData, snapshot.data) + if err != nil { + logrus.WithError(err).Errorln("Failed to create merge patch for configuration diffing") + return nil, err + } + return &buildMergedConfigResult{ + config: mergedConfig, + outgoingPatch: outgoingPatch, + }, nil + }) if err != nil { - logrus.WithError(err).Errorln("Failed to create merge patch for configuration diffing") return err } - // Send the outgoing changes back to the server to update its configuration - go func() { logrus.Debugln("Sending configuration updates back to server") @@ -133,10 +140,102 @@ func (c *Config) MergeConfiguration(config *RegistrationResponse) error { } +func (c *Config) applyMergedConfigWithRetries(build func(snapshot *configPatchSnapshot) (*buildMergedConfigResult, error)) ([]byte, error) { + for attempt := range maxMergeConfigurationRetries { + snapshot, err := c.snapshotConfigPatch() + if err != nil { + logrus.WithError(err).Errorln("Failed to marshal existing configuration for diffing") + return nil, err + } + + result, err := build(snapshot) + if err != nil { + return nil, err + } + + applied, err := c.applyMergedConfigWithSnapshot(snapshot, result.config) + if err != nil { + logrus.WithError(err).Errorln("Failed to apply incoming merged configuration") + return nil, err + } + if applied { + return result.outgoingPatch, nil + } + + logrus.WithField("attempt", attempt+1).Infoln("Configuration changed during merged sync apply, retrying") + } + + logrus.WithField("attempts", maxMergeConfigurationRetries).Warnln("Configuration changed during every merged sync attempt") + return nil, fmt.Errorf("configuration changed during merge after %d attempts", maxMergeConfigurationRetries) +} + +func (c *Config) snapshotConfigPatch() (*configPatchSnapshot, error) { + c.mu.RLock() + snapshot := ConfigPatchRequest{ + RoleConfig: &RoleConfig{ + Path: c.Roles.Path, + URL: c.Roles.URL, + Vault: c.Roles.Vault, + Definitions: c.Roles.Definitions, + }, + WorkflowConfig: &WorkflowConfig{ + Path: c.Workflows.Path, + URL: c.Workflows.URL, + Vault: c.Workflows.Vault, + Plugins: c.Workflows.Plugins, + Definitions: c.Workflows.Definitions, + }, + ProviderConfig: &ProviderDefinitionsConfig{ + Path: c.Providers.Path, + URL: c.Providers.URL, + Vault: c.Providers.Vault, + Plugins: c.Providers.Plugins, + Definitions: c.Providers.Definitions, + }, + } + generation := c.configGeneration + + data, err := json.Marshal(snapshot) + c.mu.RUnlock() + if err != nil { + return nil, err + } + + // Keep sync retries isolated from in-place nested mutations by detaching the + // snapshot through the same JSON representation used for merge-patch diffing. + var detached ConfigPatchRequest + if err := json.Unmarshal(data, &detached); err != nil { + return nil, err + } + + roleDefinitionsJSON, err := marshalJSON(detachedRoleDefinitions(detached.RoleConfig)) + if err != nil { + return nil, err + } + + workflowDefinitionsJSON, err := marshalJSON(detachedWorkflowDefinitions(detached.WorkflowConfig)) + if err != nil { + return nil, err + } + + providerDefinitionsJSON, err := marshalJSON(detachedProviderDefinitions(detached.ProviderConfig)) + if err != nil { + return nil, err + } + + return &configPatchSnapshot{ + generation: generation, + request: detached, + data: data, + roleDefinitionsJSON: roleDefinitionsJSON, + workflowDefinitionsJSON: workflowDefinitionsJSON, + providerDefinitionsJSON: providerDefinitionsJSON, + }, nil +} + func (c *Config) applyPatch(diff ConfigPatchRequest) error { // applyPatch is the partial-patch helper: merge the incoming section diff // with the current live section, then normalize and persist the result. - // Apply role changes if diff.RoleConfig != nil { err := c.updateRoles(diff.RoleConfig) if err != nil { @@ -145,7 +244,6 @@ func (c *Config) applyPatch(diff ConfigPatchRequest) error { } } - // Apply workflow changes if diff.WorkflowConfig != nil { err := c.updateWorkflows(diff.WorkflowConfig) if err != nil { @@ -154,7 +252,6 @@ func (c *Config) applyPatch(diff ConfigPatchRequest) error { } } - // Apply provider changes if diff.ProviderConfig != nil { err := c.updateProviders(diff.ProviderConfig) if err != nil { @@ -169,28 +266,124 @@ func (c *Config) applyPatch(diff ConfigPatchRequest) error { // applyMergedConfig applies a fully merged server state. It normalizes each // definitions map and stores it directly without re-merging the same sections. func (c *Config) applyMergedConfig(config ConfigPatchRequest) error { + snapshot, err := c.snapshotConfigPatch() + if err != nil { + return err + } + + applied, err := c.applyMergedConfigWithSnapshot(snapshot, config) + if err != nil { + return err + } + if !applied { + return fmt.Errorf("configuration changed while applying merged configuration") + } + + return nil +} + +func (c *Config) applyMergedConfigWithSnapshot(snapshot *configPatchSnapshot, config ConfigPatchRequest) (bool, error) { + normalized, err := c.normalizeMergedConfig(snapshot, config) + if err != nil { + return false, err + } + + return c.commitMergedDefinitions(normalized, snapshot.generation), nil +} + +func (c *Config) normalizeMergedConfig(snapshot *configPatchSnapshot, config ConfigPatchRequest) (*normalizedDefinitionsPatch, error) { + normalized := &normalizedDefinitionsPatch{} + if config.RoleConfig != nil { - if err := c.storeRoleDefinitions(config.RoleConfig.Definitions); err != nil { - logrus.WithError(err).Errorln("Failed to apply merged role configuration") - return err + mergedRoleDefinitionsJSON, err := marshalJSON(detachedRoleDefinitions(config.RoleConfig)) + if err != nil { + return nil, err + } + if !bytes.Equal(snapshot.roleDefinitionsJSON, mergedRoleDefinitionsJSON) { + defs, defsJSON, err := c.normalizeRoleDefinitions(detachedRoleDefinitions(config.RoleConfig)) + if err != nil { + logrus.WithError(err).Errorln("Failed to normalize merged role configuration") + return nil, err + } + normalized.roleDefinitionsJSON = defsJSON + normalized.roleDefinitionsChanged = !bytes.Equal(snapshot.roleDefinitionsJSON, defsJSON) + if normalized.roleDefinitionsChanged { + normalized.roleDefinitions = defs + } } } if config.WorkflowConfig != nil { - if err := c.storeWorkflowDefinitions(config.WorkflowConfig.Definitions); err != nil { - logrus.WithError(err).Errorln("Failed to apply merged workflow configuration") - return err + mergedWorkflowDefinitionsJSON, err := marshalJSON(detachedWorkflowDefinitions(config.WorkflowConfig)) + if err != nil { + return nil, err + } + if !bytes.Equal(snapshot.workflowDefinitionsJSON, mergedWorkflowDefinitionsJSON) { + defs, defsJSON, err := c.normalizeWorkflowDefinitions(detachedWorkflowDefinitions(config.WorkflowConfig)) + if err != nil { + logrus.WithError(err).Errorln("Failed to normalize merged workflow configuration") + return nil, err + } + normalized.workflowDefinitionsJSON = defsJSON + normalized.workflowDefinitionsChanged = !bytes.Equal(snapshot.workflowDefinitionsJSON, defsJSON) + if normalized.workflowDefinitionsChanged { + normalized.workflowDefinitions = defs + } } } if config.ProviderConfig != nil { - if err := c.storeProviderDefinitions(config.ProviderConfig.Definitions); err != nil { - logrus.WithError(err).Errorln("Failed to apply merged provider configuration") - return err + mergedProviderDefinitionsJSON, err := marshalJSON(detachedProviderDefinitions(config.ProviderConfig)) + if err != nil { + return nil, err + } + if !bytes.Equal(snapshot.providerDefinitionsJSON, mergedProviderDefinitionsJSON) { + defs, defsJSON, err := c.normalizeProviderDefinitions(detachedProviderDefinitions(config.ProviderConfig)) + if err != nil { + logrus.WithError(err).Errorln("Failed to normalize merged provider configuration") + return nil, err + } + normalized.providerDefinitionsJSON = defsJSON + normalized.providerDefinitionsChanged = !bytes.Equal(snapshot.providerDefinitionsJSON, defsJSON) + if normalized.providerDefinitionsChanged { + normalized.providerDefinitions = defs + } } } - return nil + return normalized, nil +} + +func (c *Config) commitMergedDefinitions(diff *normalizedDefinitionsPatch, expectedGeneration uint64) bool { + c.mu.Lock() + defer c.mu.Unlock() + + if c.configGeneration != expectedGeneration { + return false + } + + changed := false + if diff.roleDefinitionsChanged { + c.Roles.Definitions = diff.roleDefinitions + changed = true + } + if diff.workflowDefinitionsChanged { + c.Workflows.Definitions = diff.workflowDefinitions + changed = true + } + if diff.providerDefinitionsChanged { + c.Providers.Definitions = diff.providerDefinitions + changed = true + } + if changed { + c.configGeneration++ + } + + return true +} + +func marshalJSON(value any) ([]byte, error) { + return json.Marshal(value) } func mergeConfigSection(current any, incoming any, out any) error { @@ -269,46 +462,148 @@ func (c *Config) updateProviders(providerConfig *ProviderDefinitionsConfig) erro } func (c *Config) storeRoleDefinitions(definitions map[string]models.Role) error { - defs, err := c.ApplyRoles([]*models.RoleDefinitions{{ + defs, defsJSON, err := c.normalizeRoleDefinitions(definitions) + if err != nil { + return err + } + + return c.commitRoleDefinitions(defs, defsJSON) +} + +func (c *Config) storeWorkflowDefinitions(definitions map[string]models.Workflow) error { + defs, defsJSON, err := c.normalizeWorkflowDefinitions(definitions) + if err != nil { + return err + } + + return c.commitWorkflowDefinitions(defs, defsJSON) +} + +func (c *Config) storeProviderDefinitions(definitions map[string]models.ProviderConfig) error { + defs, defsJSON, err := c.normalizeProviderDefinitions(definitions) + if err != nil { + return err + } + + return c.commitProviderDefinitions(defs, defsJSON) +} + +func (c *Config) normalizeRoleDefinitions(definitions map[string]models.Role) (map[string]models.Role, []byte, error) { + defs, err := (&Config{}).ApplyRoles([]*models.RoleDefinitions{{ Roles: definitions, }}) if err != nil { - return err + return nil, nil, err } - c.mu.Lock() - c.Roles.Definitions = defs - c.mu.Unlock() + defsJSON, err := marshalJSON(defs) + if err != nil { + return nil, nil, err + } - return nil + return defs, defsJSON, nil } -func (c *Config) storeWorkflowDefinitions(definitions map[string]models.Workflow) error { - defs, err := c.ApplyWorkflows([]*models.WorkflowDefinitions{{ +func (c *Config) normalizeWorkflowDefinitions(definitions map[string]models.Workflow) (map[string]models.Workflow, []byte, error) { + defs, err := (&Config{mode: c.mode}).ApplyWorkflows([]*models.WorkflowDefinitions{{ Workflows: definitions, }}) if err != nil { - return err + return nil, nil, err } - c.mu.Lock() - c.Workflows.Definitions = defs - c.mu.Unlock() + defsJSON, err := marshalJSON(defs) + if err != nil { + return nil, nil, err + } - return nil + return defs, defsJSON, nil } -func (c *Config) storeProviderDefinitions(definitions map[string]models.ProviderConfig) error { - defs, err := c.ApplyProviders([]*models.ProviderDefinitions{{ +func (c *Config) normalizeProviderDefinitions(definitions map[string]models.ProviderConfig) (map[string]models.ProviderConfig, []byte, error) { + defs, err := (&Config{}).ApplyProviders([]*models.ProviderDefinitions{{ Providers: definitions, }}) + if err != nil { + return nil, nil, err + } + + defsJSON, err := marshalJSON(defs) + if err != nil { + return nil, nil, err + } + + return defs, defsJSON, nil +} + +func (c *Config) commitRoleDefinitions(defs map[string]models.Role, defsJSON []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + + currentJSON, err := marshalJSON(c.Roles.Definitions) if err != nil { return err } + if bytes.Equal(currentJSON, defsJSON) { + return nil + } + c.Roles.Definitions = defs + c.configGeneration++ + return nil +} + +func (c *Config) commitWorkflowDefinitions(defs map[string]models.Workflow, defsJSON []byte) error { c.mu.Lock() - c.Providers.Definitions = defs - c.mu.Unlock() + defer c.mu.Unlock() + + currentJSON, err := marshalJSON(c.Workflows.Definitions) + if err != nil { + return err + } + if bytes.Equal(currentJSON, defsJSON) { + return nil + } + c.Workflows.Definitions = defs + c.configGeneration++ return nil } + +func (c *Config) commitProviderDefinitions(defs map[string]models.ProviderConfig, defsJSON []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + + currentJSON, err := marshalJSON(c.Providers.Definitions) + if err != nil { + return err + } + if bytes.Equal(currentJSON, defsJSON) { + return nil + } + + c.Providers.Definitions = defs + c.configGeneration++ + return nil +} + +func detachedRoleDefinitions(config *RoleConfig) map[string]models.Role { + if config == nil { + return nil + } + return config.Definitions +} + +func detachedWorkflowDefinitions(config *WorkflowConfig) map[string]models.Workflow { + if config == nil { + return nil + } + return config.Definitions +} + +func detachedProviderDefinitions(config *ProviderDefinitionsConfig) map[string]models.ProviderConfig { + if config == nil { + return nil + } + return config.Definitions +} diff --git a/internal/config/sync_test.go b/internal/config/sync_test.go index 21bd992d..07c4f0ce 100644 --- a/internal/config/sync_test.go +++ b/internal/config/sync_test.go @@ -10,6 +10,7 @@ import ( "time" jsonpatch "github.com/evanphx/json-patch" + "github.com/hashicorp/go-version" "github.com/serverlessworkflow/sdk-go/v3/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -128,6 +129,16 @@ func makeTestWorkflow(name, description string) models.Workflow { } } +func makeNormalizedRole(name, description string) models.Role { + return models.Role{ + Version: version.Must(version.NewVersion("1.0")), + Identifier: name, + Name: name, + Description: description, + Enabled: true, + } +} + func makeTestProvider(name, description string) models.ProviderConfig { return models.ProviderConfig{ Name: name, @@ -137,6 +148,16 @@ func makeTestProvider(name, description string) models.ProviderConfig { } } +func makeTestProviderWithConfig(name, description string, config map[string]any) models.ProviderConfig { + provider := makeTestProvider(name, description) + basicConfig := models.BasicConfig{} + for key, value := range config { + basicConfig[key] = value + } + provider.Config = &basicConfig + return provider +} + // waitForPatch waits for a single PATCH call on the channel, or times out. func waitForPatch(ch <-chan syncPatchCall, timeout time.Duration) (syncPatchCall, bool) { select { @@ -330,7 +351,7 @@ func TestMergeConfiguration_IdenticalConfigs(t *testing.T) { server, patchCh := newSyncTestServer(t) roles := map[string]models.Role{ - "viewer": {Name: "viewer", Description: "Read-only", Enabled: true}, + "viewer": makeNormalizedRole("viewer", "Read-only"), } config := newSyncTestConfig(t, roles, nil, nil, server.URL) @@ -340,12 +361,41 @@ func TestMergeConfiguration_IdenticalConfigs(t *testing.T) { err := config.MergeConfiguration(reg) require.NoError(t, err) + assert.Equal(t, uint64(0), config.configGeneration, "identical sync should not advance the generation") // The outgoing goroutine still fires (with an empty or no-op patch) _, ok := waitForPatch(patchCh, 5*time.Second) require.True(t, ok, "expected outgoing PATCH call") } +func TestMergeConfiguration_MetadataOnlyRoleChangesAreIgnored(t *testing.T) { + server, patchCh := newSyncTestServer(t) + + config := newSyncTestConfig(t, + map[string]models.Role{ + "existing": makeNormalizedRole("existing", "same"), + }, + nil, nil, server.URL, + ) + config.Roles.Path = "./local-roles" + + reg := makeRegistrationResponse( + map[string]models.Role{ + "existing": makeNormalizedRole("existing", "same"), + }, + nil, nil, + ) + reg.Roles.Path = "./remote-roles" + + err := config.MergeConfiguration(reg) + require.NoError(t, err) + assert.Equal(t, uint64(0), config.configGeneration, "metadata-only sync changes should not advance generation") + assert.Equal(t, "./local-roles", config.Roles.Path, "definitions-only sync should not rewrite role metadata") + + _, ok := waitForPatch(patchCh, 5*time.Second) + require.True(t, ok, "expected outgoing PATCH call") +} + func TestMergeConfiguration_PartialConfig_OnlyRoles(t *testing.T) { server, patchCh := newSyncTestServer(t) @@ -616,6 +666,28 @@ func TestApplyPatch_AppliesRoles(t *testing.T) { err := config.applyPatch(diff) assert.NoError(t, err) assert.Contains(t, config.Roles.Definitions, "new-role") + assert.Equal(t, uint64(1), config.configGeneration) +} + +func TestApplyPatch_IdenticalRolesDoNotAdvanceGeneration(t *testing.T) { + config := newSyncTestConfig(t, + map[string]models.Role{ + "existing": makeNormalizedRole("existing", "same"), + }, + nil, nil, "", + ) + + diff := ConfigPatchRequest{ + RoleConfig: &RoleConfig{ + Definitions: map[string]models.Role{ + "existing": makeNormalizedRole("existing", "same"), + }, + }, + } + + err := config.applyPatch(diff) + require.NoError(t, err) + assert.Equal(t, uint64(0), config.configGeneration, "no-op apply should not advance generation") } func TestApplyPatch_SkipsNilWorkflows(t *testing.T) { @@ -639,6 +711,208 @@ func TestApplyPatch_SkipsNilWorkflows(t *testing.T) { assert.NoError(t, err) } +func TestApplyMergedConfigWithSnapshot_RejectsStaleGeneration(t *testing.T) { + config := newSyncTestConfig(t, + map[string]models.Role{ + "existing": makeNormalizedRole("existing", "before"), + }, + nil, nil, "", + ) + + snapshot, err := config.snapshotConfigPatch() + require.NoError(t, err) + + config.mu.Lock() + config.configGeneration++ + config.mu.Unlock() + + applied, err := config.applyMergedConfigWithSnapshot(snapshot, ConfigPatchRequest{ + RoleConfig: &RoleConfig{ + Definitions: map[string]models.Role{ + "existing": makeNormalizedRole("existing", "after"), + }, + }, + }) + require.NoError(t, err) + assert.False(t, applied, "expected generation mismatch to reject the stale merged apply") + assert.Equal(t, "before", config.Roles.Definitions["existing"].Description) +} + +func TestApplyMergedConfigWithRetries_RetriesAndSucceeds(t *testing.T) { + config := newSyncTestConfig(t, + map[string]models.Role{ + "existing": makeNormalizedRole("existing", "before"), + }, + nil, nil, "", + ) + + attempts := 0 + outgoingPatch, err := config.applyMergedConfigWithRetries(func(snapshot *configPatchSnapshot) (*buildMergedConfigResult, error) { + attempts++ + + if attempts == 1 { + config.mu.Lock() + config.configGeneration++ + config.mu.Unlock() + } + + return &buildMergedConfigResult{ + config: ConfigPatchRequest{ + RoleConfig: &RoleConfig{ + Definitions: map[string]models.Role{ + "existing": makeNormalizedRole("existing", "after"), + }, + }, + }, + outgoingPatch: []byte(`{"roles":{}}`), + }, nil + }) + require.NoError(t, err) + assert.Equal(t, 2, attempts, "expected one retry before success") + assert.JSONEq(t, `{"roles":{}}`, string(outgoingPatch)) + assert.Equal(t, "after", config.Roles.Definitions["existing"].Description) + assert.Equal(t, uint64(2), config.configGeneration) +} + +func TestSnapshotConfigPatch_DetachesRoleSlices(t *testing.T) { + config := newSyncTestConfig(t, + map[string]models.Role{ + "editor": { + Name: "editor", + Providers: []string{"aws-prod"}, + Permissions: models.RolePermissions{ + Allow: models.RoleStatements{{ + Operations: []string{"s3:GetObject"}, + }}, + }, + Enabled: true, + }, + }, + nil, nil, "", + ) + + snapshot, err := config.snapshotConfigPatch() + require.NoError(t, err) + + role := config.Roles.Definitions["editor"] + role.Providers[0] = "gcp-prod" + role.Permissions.Allow[0].Operations[0] = "storage.objects.get" + + snapRole, exists := snapshot.request.RoleConfig.Definitions["editor"] + require.True(t, exists) + assert.Equal(t, "aws-prod", snapRole.Providers[0]) + assert.Equal(t, "s3:GetObject", snapRole.Permissions.Allow[0].Operations[0]) +} + +func TestSnapshotConfigPatch_DetachesWorkflowDefinitions(t *testing.T) { + config := newSyncTestConfig(t, + nil, + map[string]models.Workflow{ + "approval": makeTestWorkflow("approval", "original"), + }, + nil, "", + ) + + snapshot, err := config.snapshotConfigPatch() + require.NoError(t, err) + + workflow := config.Workflows.Definitions["approval"] + workflow.Workflow.Do = nil + + snapWorkflow, exists := snapshot.request.WorkflowConfig.Definitions["approval"] + require.True(t, exists) + require.NotNil(t, snapWorkflow.Workflow) + assert.NotNil(t, snapWorkflow.Workflow.Do) +} + +func TestSnapshotConfigPatch_DetachesProviderConfig(t *testing.T) { + config := &Config{ + Providers: ProviderDefinitionsConfig{ + Definitions: map[string]models.ProviderConfig{ + "mock-primary": makeTestProviderWithConfig("mock-primary", "primary", map[string]any{ + "region": "us-east-1", + }), + }, + }, + } + + snapshot, err := config.snapshotConfigPatch() + require.NoError(t, err) + + provider := config.Providers.Definitions["mock-primary"] + require.NotNil(t, provider.Config) + provider.Config.SetKeyWithValue("region", "eu-west-1") + + snapProvider, exists := snapshot.request.ProviderConfig.Definitions["mock-primary"] + require.True(t, exists) + require.NotNil(t, snapProvider.Config) + region, ok := snapProvider.Config.GetString("region") + require.True(t, ok) + assert.Equal(t, "us-east-1", region) +} + +func TestKnownGap_NestedProviderMutationsShouldAdvanceConfigGeneration(t *testing.T) { + t.Skip("expected failure until #306: nested provider config mutations bypass configGeneration") + + config := &Config{ + Providers: ProviderDefinitionsConfig{ + Definitions: map[string]models.ProviderConfig{ + "mock-primary": makeTestProviderWithConfig("mock-primary", "primary", map[string]any{ + "region": "us-east-1", + }), + }, + }, + } + + provider := config.Providers.Definitions["mock-primary"] + require.NotNil(t, provider.Config) + provider.Config.SetKeyWithValue("region", "eu-west-1") + + assert.Equal(t, uint64(1), config.configGeneration, "nested provider mutations should participate in generation tracking") +} + +func TestKnownGap_NestedRoleMutationsShouldAdvanceConfigGeneration(t *testing.T) { + t.Skip("expected failure until #306: nested role mutations bypass configGeneration") + + config := newSyncTestConfig(t, + map[string]models.Role{ + "editor": { + Name: "editor", + Providers: []string{"aws-prod"}, + Permissions: models.RolePermissions{ + Allow: models.RoleStatements{{ + Operations: []string{"s3:GetObject"}, + }}, + }, + Enabled: true, + }, + }, + nil, nil, "", + ) + + role := config.Roles.Definitions["editor"] + role.Permissions.Allow[0].Operations[0] = "storage.objects.get" + + assert.Equal(t, uint64(1), config.configGeneration, "nested role mutations should participate in generation tracking") +} + +func TestKnownGap_NestedWorkflowMutationsShouldAdvanceConfigGeneration(t *testing.T) { + t.Skip("expected failure until #306: nested workflow mutations bypass configGeneration") + + config := newSyncTestConfig(t, + nil, + map[string]models.Workflow{ + "approval": makeTestWorkflow("approval", "original"), + }, + nil, "", + ) + + workflow := config.Workflows.Definitions["approval"] + workflow.Workflow.Do = nil + + assert.Equal(t, uint64(1), config.configGeneration, "nested workflow mutations should participate in generation tracking") +} + // --------------------------------------------------------------------------- // Tests for the merge-patch logic (JSON diffing correctness) // --------------------------------------------------------------------------- From 6d1258c9a2414ee372d15c43ae56c95888597e5c Mon Sep 17 00:00:00 2001 From: Michael Weber Date: Sun, 19 Apr 2026 23:16:04 -0500 Subject: [PATCH 03/23] test: add self-contained agent login e2e --- internal/daemon/server.go | 37 ++- internal/daemon/server_test.go | 39 +++ .../frontend/agent_login_e2e_test.go | 99 +++++++ test/integration/frontend/browser.go | 39 ++- .../frontend/infrastructure_test.go | 261 ++++++++++++++++-- .../testdata/agent-login/providers.yaml | 17 ++ .../frontend/testdata/agent-login/roles.yaml | 2 + .../testdata/agent-login/workflow.yaml | 2 + 8 files changed, 462 insertions(+), 34 deletions(-) create mode 100644 internal/daemon/server_test.go create mode 100644 test/integration/frontend/agent_login_e2e_test.go create mode 100644 test/integration/frontend/testdata/agent-login/providers.yaml create mode 100644 test/integration/frontend/testdata/agent-login/roles.yaml create mode 100644 test/integration/frontend/testdata/agent-login/workflow.yaml diff --git a/internal/daemon/server.go b/internal/daemon/server.go index d15bb220..a9239bb8 100644 --- a/internal/daemon/server.go +++ b/internal/daemon/server.go @@ -620,7 +620,7 @@ func (s *Server) apiConfigurationHandler(c *gin.Context) { workflows := []string{} // TODO: populate workflows list activities := []string{} // TODO: populate activities list - // For agent / client we show the local server for discvoery. + // For agent / client we show the local server for discovery. baseUrl := s.Config.GetLocalServerUrl() if s.Config.IsServer() { @@ -633,9 +633,14 @@ func (s *Server) apiConfigurationHandler(c *gin.Context) { capabilities["llm"] = services.HasLargeLanguageModel() capabilities["storage"] = services.HasStorage() - // However, for server we show the login server as the main - // entry point for clients to connect to - baseUrl = s.Config.GetLoginServerUrl() + // Prefer the origin the caller actually used so discovery stays + // reachable across local test hostnames and reverse proxies. + if requestBaseURL := getRequestBaseURL(c.Request); len(requestBaseURL) > 0 { + baseUrl = requestBaseURL + } else { + // Fall back to the configured login server entry point. + baseUrl = s.Config.GetLoginServerUrl() + } } response := gin.H{ @@ -674,6 +679,30 @@ func (s *Server) apiConfigurationHandler(c *gin.Context) { c.JSON(http.StatusOK, response) } +func getRequestBaseURL(req *http.Request) string { + if req == nil { + return "" + } + + host := strings.TrimSpace(req.Host) + if len(host) == 0 && req.URL != nil { + host = strings.TrimSpace(req.URL.Host) + } + if len(host) == 0 { + return "" + } + + scheme := "http" + if req.TLS != nil { + scheme = "https" + } + if forwardedProto := strings.TrimSpace(req.Header.Get("X-Forwarded-Proto")); len(forwardedProto) > 0 { + scheme = strings.TrimSpace(strings.Split(forwardedProto, ",")[0]) + } + + return fmt.Sprintf("%s://%s", scheme, host) +} + // readyHandler handles the readiness check endpoint // // @Summary Readiness check diff --git a/internal/daemon/server_test.go b/internal/daemon/server_test.go new file mode 100644 index 00000000..c2db3094 --- /dev/null +++ b/internal/daemon/server_test.go @@ -0,0 +1,39 @@ +package daemon + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/thand-io/agent/internal/config" +) + +func TestAPIConfigurationHandlerPrefersRequestOriginInServerMode(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + + cfg := config.DefaultConfig() + cfg.SetMode(config.ModeServer) + require.NoError(t, cfg.SetLoginServer("http://localhost:5225")) + + server := NewServer(cfg) + router := gin.New() + router.GET("/.well-known/api-configuration", server.apiConfigurationHandler) + + req := httptest.NewRequest(http.MethodGet, "/.well-known/api-configuration", nil) + req.Host = "thand.test:5225" + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var response struct { + BaseURL string `json:"baseUrl"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &response)) + require.Equal(t, "http://thand.test:5225", response.BaseURL) +} diff --git a/test/integration/frontend/agent_login_e2e_test.go b/test/integration/frontend/agent_login_e2e_test.go new file mode 100644 index 00000000..109e2033 --- /dev/null +++ b/test/integration/frontend/agent_login_e2e_test.go @@ -0,0 +1,99 @@ +package ui_e2e + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/cookiejar" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/thand-io/agent/internal/models" + "github.com/thand-io/agent/test/integration/testinfra" +) + +// TestAgentLoginE2E exercises the browser login flow through a self-contained +// server+agent test environment. +// +// Run with Brave by setting CHROME_BIN, for example: +// +// CHROME_BIN="/Applications/Brave Browser.app/Contents/MacOS/Brave Browser" go test -v -run TestAgentLoginE2E ./integration/frontend/... +func TestAgentLoginE2E(t *testing.T) { + if testing.Short() { + t.Skip("Skipping UI integration test in short mode") + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + loader := testinfra.NewTestCaseLoader(nil, "testdata") + testCase, err := loader.LoadTestCase("agent-login") + require.NoError(t, err, "Failed to load agent-login test case") + + infra := SetupUITestInfrastructure(t, ctx, testCase, WithAgentContainer(), WithAgentLocalDefinitions()) + defer infra.Teardown() + + require.NotEmpty(t, infra.AgentEndpoint, "Agent endpoint should be configured") + + browser := NewBrowser(t, infra.AgentEndpoint) + defer browser.Close() + + err = browser.Login(ctx, "testuser@thand.io", "testpass123") + require.NoError(t, err, "Agent login should succeed") + + err = browser.WaitForAuthCallbackSuccess(ctx, 30*time.Second) + require.NoError(t, err, "Agent auth callback should register the session") + + sessions := waitForAgentSessions(t, infra.AgentEndpoint, "oidc-test", 30*time.Second) + require.Contains(t, sessions.Sessions, "oidc-test", "Agent should expose the authenticated session") + localSession := sessions.Sessions["oidc-test"] + require.False(t, localSession.IsExpired(), "OIDC session should remain active") +} + +func waitForAgentSessions(t *testing.T, agentEndpoint string, provider string, timeout time.Duration) *models.SessionsResponse { + t.Helper() + + jar, err := cookiejar.New(nil) + require.NoError(t, err, "Failed to create cookie jar for agent session polling") + + client := &http.Client{ + Timeout: 10 * time.Second, + Jar: jar, + } + deadline := time.Now().Add(timeout) + + for time.Now().Before(deadline) { + sessions, err := fetchAgentSessions(client, agentEndpoint) + if err == nil { + if _, found := sessions.Sessions[provider]; found { + return sessions + } + } + + time.Sleep(500 * time.Millisecond) + } + + t.Fatalf("timed out waiting for agent session %q at %s", provider, agentEndpoint) + return nil +} + +func fetchAgentSessions(client *http.Client, agentEndpoint string) (*models.SessionsResponse, error) { + resp, err := client.Get(agentEndpoint + "/api/v1/sessions") + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected sessions status: %s", resp.Status) + } + + var sessions models.SessionsResponse + if err := json.NewDecoder(resp.Body).Decode(&sessions); err != nil { + return nil, err + } + + return &sessions, nil +} diff --git a/test/integration/frontend/browser.go b/test/integration/frontend/browser.go index eb13c202..cdd6265b 100644 --- a/test/integration/frontend/browser.go +++ b/test/integration/frontend/browser.go @@ -39,7 +39,11 @@ func NewBrowser(t *testing.T, baseURL string) *Browser { chromedp.Flag("disable-gpu", false), chromedp.Flag("no-sandbox", true), chromedp.Flag("disable-dev-shm-usage", true), - chromedp.Flag("host-resolver-rules", fmt.Sprintf("MAP %s 127.0.0.1", testinfra.KeycloakSharedHostname)), + chromedp.Flag("host-resolver-rules", fmt.Sprintf( + "MAP %s 127.0.0.1, MAP %s 127.0.0.1", + testinfra.KeycloakSharedHostname, + ThandSharedHostname, + )), ) // If CHROME_BIN is set (e.g., in CI environments), use it @@ -178,6 +182,39 @@ func (b *Browser) Login(ctx context.Context, username, password string) error { return b.LoginWithOIDC(ctx, username, password) } +// WaitForAuthCallbackSuccess waits for the auth callback page to report that +// session registration completed successfully. +func (b *Browser) WaitForAuthCallbackSuccess(ctx context.Context, timeout time.Duration) error { + b.t.Log("Waiting for auth callback success...") + + chromedpCtx, cancel := b.withCallerCtx(ctx) + defer cancel() + + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + var currentURL string + var bodyText string + + err := chromedp.Run(chromedpCtx, + chromedp.Location(¤tURL), + chromedp.Text(`body`, &bodyText, chromedp.ByQuery), + ) + if err == nil { + if strings.Contains(bodyText, "Session registered successfully!") { + b.t.Logf("Auth callback succeeded at URL: %s", currentURL) + return nil + } + if strings.Contains(bodyText, "Unable to register session automatically") { + return fmt.Errorf("auth callback failed to register session at URL %s", currentURL) + } + } + + time.Sleep(500 * time.Millisecond) + } + + return fmt.Errorf("timeout waiting for auth callback success") +} + // NavigateToElevatePage navigates to the elevate page func (b *Browser) NavigateToElevatePage(ctx context.Context) error { b.t.Log("Navigating to elevate page...") diff --git a/test/integration/frontend/infrastructure_test.go b/test/integration/frontend/infrastructure_test.go index 61d0295d..c71a331b 100644 --- a/test/integration/frontend/infrastructure_test.go +++ b/test/integration/frontend/infrastructure_test.go @@ -24,18 +24,48 @@ import ( const ( // ThandServerPort is the default HTTP port for Thand server. ThandServerPort = "5225" + // ThandSharedHostname is a synthetic hostname that both the browser and the + // agent container can resolve to the host-bound server port. + ThandSharedHostname = "thand.test" ) +type UISetupOption func(*uiSetupConfig) + +type uiSetupConfig struct { + startAgent bool + agentUsesLocalDefinitions bool +} + +func WithAgentContainer() UISetupOption { + return func(cfg *uiSetupConfig) { + cfg.startAgent = true + } +} + +func WithAgentLocalDefinitions() UISetupOption { + return func(cfg *uiSetupConfig) { + cfg.agentUsesLocalDefinitions = true + } +} + // UITestInfrastructure extends testinfra with UI-specific containers (Thand server). type UITestInfrastructure struct { *testinfra.TestInfrastructure // Thand Server - thandContainer testcontainers.Container - ThandEndpoint string - ThandAPIEndpoint string - allocatedHostPort int // pre-allocated host port for deterministic URLs - portListener net.Listener // held open until startThandServer closes it just before Docker binds + thandContainer testcontainers.Container + ThandEndpoint string + ThandAPIEndpoint string + ThandSharedEndpoint string + allocatedHostPort int // pre-allocated host port for deterministic URLs + portListener net.Listener // held open until startThandServer closes it just before Docker binds + + // Optional Thand Agent + agentContainer testcontainers.Container + AgentEndpoint string + AgentAPIEndpoint string + allocatedAgentHostPort int + agentPortListener net.Listener // providerEnvVars are injected into the server container so that ResolveConfig // can substitute ${ .VARNAME } jq expressions in provider definition YAML files. @@ -54,12 +84,25 @@ func (l *thandServerLogConsumer) Accept(log testcontainers.Log) { l.t.Logf("[thand-server] %s", string(log.Content)) } +type thandAgentLogConsumer struct { + t *testing.T +} + +func (l *thandAgentLogConsumer) Accept(log testcontainers.Log) { + l.t.Logf("[thand-agent] %s", string(log.Content)) +} + // SetupUITestInfrastructure creates and starts all containers needed for UI E2E tests. // It starts the base infrastructure (LocalStack, MailHog, Temporal) plus Keycloak and // optionally the Thand server container. -func SetupUITestInfrastructure(t *testing.T, ctx context.Context, testCase *TestCase) *UITestInfrastructure { +func SetupUITestInfrastructure(t *testing.T, ctx context.Context, testCase *TestCase, opts ...UISetupOption) *UITestInfrastructure { t.Helper() + setupCfg := &uiSetupConfig{} + for _, opt := range opts { + opt(setupCfg) + } + // Resolve Keycloak realm file path realmPath, err := filepath.Abs(filepath.Join("keycloak", "thand-test-realm.json")) require.NoError(t, err, "Failed to resolve Keycloak realm path") @@ -82,11 +125,24 @@ func SetupUITestInfrastructure(t *testing.T, ctx context.Context, testCase *Test infra.allocatedHostPort = listener.Addr().(*net.TCPAddr).Port infra.portListener = listener // keep open; closed in startThandServer just before container bind - // Set endpoint URLs early so createConfigDir can interpolate ${THAND_SERVER_URL} + // Set endpoint URLs early so createConfigDir can interpolate provider callback + // URLs correctly. Browser-facing flows use localhost, while thand.test exists + // only so containers can reach the host-bound server port. infra.ThandEndpoint = fmt.Sprintf("http://localhost:%d", infra.allocatedHostPort) infra.ThandAPIEndpoint = infra.ThandEndpoint + "/api/v1" + infra.ThandSharedEndpoint = fmt.Sprintf("http://%s:%d", ThandSharedHostname, infra.allocatedHostPort) t.Logf("Pre-allocated Thand server port: %d → %s", infra.allocatedHostPort, infra.ThandEndpoint) + if setupCfg.startAgent { + agentListener, agentErr := net.Listen("tcp", ":0") + require.NoError(t, agentErr, "Failed to allocate a free port for Thand agent") + infra.allocatedAgentHostPort = agentListener.Addr().(*net.TCPAddr).Port + infra.agentPortListener = agentListener + infra.AgentEndpoint = fmt.Sprintf("http://localhost:%d", infra.allocatedAgentHostPort) + infra.AgentAPIEndpoint = infra.AgentEndpoint + "/api/v1" + t.Logf("Pre-allocated Thand agent port: %d → %s", infra.allocatedAgentHostPort, infra.AgentEndpoint) + } + // Create temporary config directory with interpolated values. // Must be done after infrastructure is started so we have endpoints. infra.createConfigDir(t, testCase) @@ -98,6 +154,10 @@ func SetupUITestInfrastructure(t *testing.T, ctx context.Context, testCase *Test t.Log("Skipping Thand server start (empty test case)") } + if setupCfg.startAgent { + infra.startThandAgent(t, ctx, setupCfg.agentUsesLocalDefinitions) + } + return infra } @@ -149,9 +209,10 @@ func (infra *UITestInfrastructure) createConfigDir(t *testing.T, testCase *TestC infra.providerEnvVars["SAML_IDP_METADATA_URL_INTERNAL"] = samlMetaInternal } - // THAND_SERVER_URL uses localhost — the browser navigates here after Keycloak redirect. + // THAND_SERVER_URL must stay browser-facing so auth callbacks and session + // cookies remain on the same localhost origin as the rest of the UI flow. if infra.allocatedHostPort > 0 { - infra.providerEnvVars["THAND_SERVER_URL"] = fmt.Sprintf("http://localhost:%d", infra.allocatedHostPort) + infra.providerEnvVars["THAND_SERVER_URL"] = infra.ThandEndpoint } // Copy definition files verbatim — no string substitution needed. @@ -182,26 +243,7 @@ func (infra *UITestInfrastructure) startThandServer(t *testing.T, ctx context.Co t.Helper() t.Log("Starting Thand server container...") - // Detect host architecture to select the correct Linux binary. - // Docker Desktop on macOS M-series runs arm64 containers by default. - goarch := "amd64" - if arch := os.Getenv("GOARCH"); arch != "" { - goarch = arch - } else { - // runtime.GOARCH gives the architecture of the test process - goarch = runtime.GOARCH - } - - // Build path to Linux agent binary (needed for Alpine container) - agentBinaryPath := filepath.Join("..", "..", "..", "bin", fmt.Sprintf("thand-linux-%s", goarch)) - if _, err := os.Stat(agentBinaryPath); os.IsNotExist(err) { - // Fall back to amd64 if native arch binary doesn't exist - agentBinaryPath = filepath.Join("..", "..", "..", "bin", "thand-linux-amd64") - } - if _, err := os.Stat(agentBinaryPath); os.IsNotExist(err) { - t.Fatalf("Linux agent binary not found at %s. Run 'make build-linux-amd64' first.", agentBinaryPath) - } - t.Logf("Using agent binary: %s (arch: %s)", agentBinaryPath, goarch) + agentBinaryPath := resolveLinuxAgentBinaryPath(t) // Create config.yaml for the server in a separate temp dir serverConfigDir := filepath.Join(os.TempDir(), fmt.Sprintf("thand-server-config-%d", time.Now().UnixNano())) @@ -312,6 +354,7 @@ workflows: // Ensure host.docker.internal resolves on Linux Docker (not only Docker Desktop). hc.ExtraHosts = []string{ "host.docker.internal:host-gateway", + ThandSharedHostname + ":host-gateway", testinfra.KeycloakSharedHostname + ":host-gateway", } }, @@ -364,6 +407,157 @@ workflows: } } +func (infra *UITestInfrastructure) startThandAgent(t *testing.T, ctx context.Context, withLocalDefinitions bool) { + t.Helper() + t.Log("Starting Thand agent container...") + + agentBinaryPath := resolveLinuxAgentBinaryPath(t) + + agentConfigDir := filepath.Join(os.TempDir(), fmt.Sprintf("thand-agent-config-%d", time.Now().UnixNano())) + err := os.MkdirAll(agentConfigDir, 0755) + require.NoError(t, err, "Failed to create agent config directory") + infra.RegisterCleanup(func() { + os.RemoveAll(agentConfigDir) + }) + + agentPort := strconv.Itoa(infra.allocatedAgentHostPort) + + configYAML := fmt.Sprintf(` +mode: agent +secret: thand-e2e-test-configured +login: + endpoint: %s + base: / +server: + host: 0.0.0.0 + port: %s + security: + cors: + allowed_origins: + - %s + - %s +logging: + level: "debug" +`, infra.ThandSharedEndpoint, agentPort, infra.ThandEndpoint, infra.ThandSharedEndpoint) + + if withLocalDefinitions { + configYAML += ` +providers: + path: /app/definitions +roles: + path: /app/definitions +workflows: + path: /app/definitions +` + } + + err = os.WriteFile(filepath.Join(agentConfigDir, "config.yaml"), []byte(configYAML), 0644) + require.NoError(t, err, "Failed to write agent config.yaml") + + agentPortBinding := nat.Port(agentPort + "/tcp") + containerFiles := []testcontainers.ContainerFile{ + { + HostFilePath: agentBinaryPath, + ContainerFilePath: "/app/agent", + FileMode: 0755, + }, + { + HostFilePath: filepath.Join(agentConfigDir, "config.yaml"), + ContainerFilePath: "/app/config.yaml", + FileMode: 0644, + }, + } + if withLocalDefinitions { + for _, defFile := range []struct{ hostSrc, containerDst string }{ + {"providers.yaml", "/app/definitions/providers.yaml"}, + {"roles.yaml", "/app/definitions/roles.yaml"}, + {"workflows.yaml", "/app/definitions/workflows.yaml"}, + } { + hostPath := filepath.Join(infra.configDir, defFile.hostSrc) + if _, err := os.Stat(hostPath); err == nil { + containerFiles = append(containerFiles, testcontainers.ContainerFile{ + HostFilePath: hostPath, + ContainerFilePath: defFile.containerDst, + FileMode: 0644, + }) + } + } + } + + req := testcontainers.ContainerRequest{ + Image: "alpine:3.18", + Cmd: []string{ + "/app/agent", "agent", "--config", "/app/config.yaml", + }, + ExposedPorts: []string{agentPort + "/tcp"}, + Env: map[string]string{ + "THAND_MODE": "agent", + "THAND_LOG_LEVEL": "debug", + }, + Files: containerFiles, + LogConsumerCfg: &testcontainers.LogConsumerConfig{ + Consumers: []testcontainers.LogConsumer{ + &thandAgentLogConsumer{t: t}, + }, + }, + HostConfigModifier: func(hc *dockercontainer.HostConfig) { + hc.PortBindings = nat.PortMap{ + agentPortBinding: []nat.PortBinding{ + {HostIP: "0.0.0.0", HostPort: agentPort}, + }, + } + hc.ExtraHosts = []string{ + "host.docker.internal:host-gateway", + ThandSharedHostname + ":host-gateway", + testinfra.KeycloakSharedHostname + ":host-gateway", + } + }, + WaitingFor: wait.ForListeningPort(agentPortBinding). + WithStartupTimeout(120 * time.Second), + } + + if infra.agentPortListener != nil { + infra.agentPortListener.Close() + infra.agentPortListener = nil + } + + container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + require.NoError(t, err, "Failed to start Thand agent container") + infra.agentContainer = container + + host, err := container.Host(ctx) + require.NoError(t, err, "Failed to get Thand agent host") + mappedPort, err := container.MappedPort(ctx, agentPortBinding) + require.NoError(t, err, "Failed to get Thand agent port") + + actualEndpoint := fmt.Sprintf("http://%s:%s", host, mappedPort.Port()) + t.Logf("Thand agent started at %s (expected %s)", actualEndpoint, infra.AgentEndpoint) +} + +func resolveLinuxAgentBinaryPath(t *testing.T) string { + t.Helper() + + goarch := "amd64" + if arch := os.Getenv("GOARCH"); arch != "" { + goarch = arch + } else { + goarch = runtime.GOARCH + } + + agentBinaryPath := filepath.Join("..", "..", "..", "bin", fmt.Sprintf("thand-linux-%s", goarch)) + if _, err := os.Stat(agentBinaryPath); os.IsNotExist(err) { + agentBinaryPath = filepath.Join("..", "..", "..", "bin", "thand-linux-amd64") + } + if _, err := os.Stat(agentBinaryPath); os.IsNotExist(err) { + t.Fatalf("Linux agent binary not found at %s. Run 'make build-linux-amd64' first.", agentBinaryPath) + } + t.Logf("Using agent binary: %s (arch: %s)", agentBinaryPath, goarch) + return agentBinaryPath +} + // Teardown stops and removes all UI-specific containers, then delegates to base teardown. func (infra *UITestInfrastructure) Teardown() { terminateCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) @@ -374,8 +568,17 @@ func (infra *UITestInfrastructure) Teardown() { infra.portListener.Close() infra.portListener = nil } + if infra.agentPortListener != nil { + infra.agentPortListener.Close() + infra.agentPortListener = nil + } // Terminate UI-specific containers first + if infra.agentContainer != nil { + if err := infra.agentContainer.Terminate(terminateCtx); err != nil { + _ = err + } + } if infra.thandContainer != nil { if err := infra.thandContainer.Terminate(terminateCtx); err != nil { // Log warning but don't fail diff --git a/test/integration/frontend/testdata/agent-login/providers.yaml b/test/integration/frontend/testdata/agent-login/providers.yaml new file mode 100644 index 00000000..087d4660 --- /dev/null +++ b/test/integration/frontend/testdata/agent-login/providers.yaml @@ -0,0 +1,17 @@ +version: "1.0" +providers: + oidc-test: + name: Test OIDC Provider + description: OAuth2 provider backed by Keycloak for agent login e2e testing + provider: oauth2 + enabled: true + config: + client_id: thand-oidc + client_secret: thand-oidc-secret + auth_url: '${ .OIDC_ISSUER_URL + "/protocol/openid-connect/auth" }' + token_url: '${ .OIDC_ISSUER_URL_INTERNAL + "/protocol/openid-connect/token" }' + scopes: + - openid + - profile + - email + redirect_url: '${ .THAND_SERVER_URL + "/api/v1/auth/callback/oidc-test" }' diff --git a/test/integration/frontend/testdata/agent-login/roles.yaml b/test/integration/frontend/testdata/agent-login/roles.yaml new file mode 100644 index 00000000..2f5e4586 --- /dev/null +++ b/test/integration/frontend/testdata/agent-login/roles.yaml @@ -0,0 +1,2 @@ +version: "1.0" +roles: {} diff --git a/test/integration/frontend/testdata/agent-login/workflow.yaml b/test/integration/frontend/testdata/agent-login/workflow.yaml new file mode 100644 index 00000000..a198c9bb --- /dev/null +++ b/test/integration/frontend/testdata/agent-login/workflow.yaml @@ -0,0 +1,2 @@ +version: "1.0" +workflows: {} From 5c093ce85d5f168a802ba8bbcbddcf30e93f00c3 Mon Sep 17 00:00:00 2001 From: Michael Weber Date: Mon, 20 Apr 2026 08:00:48 -0500 Subject: [PATCH 04/23] fix(config): guard config push-back sync behind real Thand service --- internal/config/model.go | 17 ++++----- internal/config/model_test.go | 66 +++++++++++++++++++++++++++++++++++ internal/config/sync.go | 10 ++++++ internal/config/sync_test.go | 39 +++++++++++++++++++++ 4 files changed, 124 insertions(+), 8 deletions(-) create mode 100644 internal/config/model_test.go diff --git a/internal/config/model.go b/internal/config/model.go index 4d3f9b46..d009e1a3 100644 --- a/internal/config/model.go +++ b/internal/config/model.go @@ -286,7 +286,7 @@ func (c *Config) GetThandServerUrl() string { } func (c *Config) DiscoverThandServerApiUrl() string { - return c.discoverServerApiUrl(c.Thand.Endpoint, &model.ReferenceableAuthenticationPolicy{ + return c.discoverServerApiUrl("Thand server", c.Thand.Endpoint, &model.ReferenceableAuthenticationPolicy{ AuthenticationPolicy: &model.AuthenticationPolicy{ Bearer: &model.BearerAuthenticationPolicy{ Token: c.Thand.ApiKey, @@ -296,11 +296,12 @@ func (c *Config) DiscoverThandServerApiUrl() string { } func (c *Config) DiscoverLoginServerApiUrl(loginServer string) string { - return c.discoverServerApiUrl(loginServer, nil) + return c.discoverServerApiUrl("login server", loginServer, nil) } func (c *Config) discoverServerApiUrl( - loginServer string, + serviceName string, + serverURL string, auth *model.ReferenceableAuthenticationPolicy, ) string { @@ -308,8 +309,8 @@ func (c *Config) discoverServerApiUrl( // /.well-known/api-configuration endpoint // to get the base param which is our api endpoint using resty - discoveryCheckUrl := fmt.Sprintf("%s/.well-known/api-configuration", loginServer) - defaultUrl := fmt.Sprintf("%s/api/v1", loginServer) + discoveryCheckUrl := fmt.Sprintf("%s/.well-known/api-configuration", serverURL) + defaultUrl := fmt.Sprintf("%s/api/v1", serverURL) resp, err := common.InvokeHttpRequest(&model.HTTPArguments{ Endpoint: &model.Endpoint{ @@ -340,12 +341,12 @@ func (c *Config) discoverServerApiUrl( } if len(discoveryCheckResponse.BaseUrl) > 0 { - logrus.Debugf("Discovered login server base URL: %s", discoveryCheckResponse.BaseUrl) - loginServer = strings.TrimSuffix(discoveryCheckResponse.BaseUrl, "/") + logrus.Debugf("Discovered %s base URL: %s", serviceName, discoveryCheckResponse.BaseUrl) + serverURL = strings.TrimSuffix(discoveryCheckResponse.BaseUrl, "/") } trimPath := strings.TrimSuffix(strings.TrimPrefix(discoveryCheckResponse.ApiBasePath, "/"), "/") - return fmt.Sprintf("%s/%s", loginServer, trimPath) + return fmt.Sprintf("%s/%s", serverURL, trimPath) } func (c *Config) GetLoginServerHostname() string { diff --git a/internal/config/model_test.go b/internal/config/model_test.go new file mode 100644 index 00000000..eb59440c --- /dev/null +++ b/internal/config/model_test.go @@ -0,0 +1,66 @@ +package config + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/sirupsen/logrus" + logrustest "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDiscoverLoginServerApiUrl_LogsLoginServerDiscovery(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/.well-known/api-configuration", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _, err := w.Write([]byte(`{"baseUrl":"https://auth.example.com","apiBasePath":"/api/v1"}`)) + require.NoError(t, err) + })) + t.Cleanup(server.Close) + + hook := logrustest.NewGlobal() + defer hook.Reset() + + oldLevel := logrus.GetLevel() + logrus.SetLevel(logrus.DebugLevel) + defer logrus.SetLevel(oldLevel) + + config := &Config{} + + apiURL := config.DiscoverLoginServerApiUrl(server.URL) + + require.Equal(t, "https://auth.example.com/api/v1", apiURL) + lastEntry := hook.LastEntry() + require.NotNil(t, lastEntry) + assert.Equal(t, "Discovered login server base URL: https://auth.example.com", lastEntry.Message) +} + +func TestDiscoverThandServerApiUrl_LogsThandServerDiscovery(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/.well-known/api-configuration", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _, err := w.Write([]byte(`{"baseUrl":"https://config.example.com","apiBasePath":"/api/v1"}`)) + require.NoError(t, err) + })) + t.Cleanup(server.Close) + + hook := logrustest.NewGlobal() + defer hook.Reset() + + oldLevel := logrus.GetLevel() + logrus.SetLevel(logrus.DebugLevel) + defer logrus.SetLevel(oldLevel) + + config := &Config{} + config.Thand.Endpoint = server.URL + config.Thand.ApiKey = "test-token" + + apiURL := config.DiscoverThandServerApiUrl() + + require.Equal(t, "https://config.example.com/api/v1", apiURL) + lastEntry := hook.LastEntry() + require.NotNil(t, lastEntry) + assert.Equal(t, "Discovered Thand server base URL: https://config.example.com", lastEntry.Message) +} diff --git a/internal/config/sync.go b/internal/config/sync.go index 51f8971b..f91e218c 100644 --- a/internal/config/sync.go +++ b/internal/config/sync.go @@ -98,6 +98,16 @@ func (c *Config) MergeConfiguration(config *RegistrationResponse) error { return err } + if !c.HasThandService() { + logrus.Debugln("Skipping configuration push-back sync because no Thand service is configured") + return nil + } + + if !c.Thand.Sync { + logrus.Debugln("Skipping configuration push-back sync because thand.sync is disabled") + return nil + } + go func() { logrus.Debugln("Sending configuration updates back to server") diff --git a/internal/config/sync_test.go b/internal/config/sync_test.go index 07c4f0ce..5d7b9a98 100644 --- a/internal/config/sync_test.go +++ b/internal/config/sync_test.go @@ -87,6 +87,7 @@ func newSyncTestConfig( Thand: models.ThandConfig{ Endpoint: endpoint, ApiKey: "test-api-key", + Sync: true, }, } @@ -168,6 +169,16 @@ func waitForPatch(ch <-chan syncPatchCall, timeout time.Duration) (syncPatchCall } } +func assertNoPatch(t *testing.T, ch <-chan syncPatchCall, timeout time.Duration) { + t.Helper() + + select { + case call := <-ch: + t.Fatalf("expected no outgoing PATCH call, got %s %s", call.Method, call.URL) + case <-time.After(timeout): + } +} + // --------------------------------------------------------------------------- // Tests for MergeConfiguration — incoming config merging // --------------------------------------------------------------------------- @@ -640,6 +651,34 @@ func TestMergeConfiguration_OutgoingPatch_URLContainsSync(t *testing.T) { assert.Contains(t, call.URL, "/sync") } +func TestMergeConfiguration_NoOutgoingPatchWithoutThandService(t *testing.T) { + server, patchCh := newSyncTestServer(t) + + config := newSyncTestConfig(t, nil, nil, nil, server.URL) + config.Thand.ApiKey = "" + + reg := makeRegistrationResponse(nil, nil, nil) + + err := config.MergeConfiguration(reg) + require.NoError(t, err) + + assertNoPatch(t, patchCh, 300*time.Millisecond) +} + +func TestMergeConfiguration_NoOutgoingPatchWhenThandSyncDisabled(t *testing.T) { + server, patchCh := newSyncTestServer(t) + + config := newSyncTestConfig(t, nil, nil, nil, server.URL) + config.Thand.Sync = false + + reg := makeRegistrationResponse(nil, nil, nil) + + err := config.MergeConfiguration(reg) + require.NoError(t, err) + + assertNoPatch(t, patchCh, 300*time.Millisecond) +} + // --------------------------------------------------------------------------- // Tests for applyPatch // --------------------------------------------------------------------------- From 86d65f81896ef082bfe480c0b9c9367c19e78726 Mon Sep 17 00:00:00 2001 From: Michael Weber Date: Thu, 23 Apr 2026 21:15:30 -0500 Subject: [PATCH 05/23] temporal: keep shared device registry queues unversioned --- internal/config/services/temporal/main.go | 60 +++++++++++++------ .../services/temporal/versioning_test.go | 49 +++++++++++++++ internal/models/device_registry_queue.go | 3 + internal/testing/temporaltest/temporaltest.go | 19 ++++++ .../tasks/providers/thand/versioning.go | 24 ++++++++ .../tasks/providers/thand/versioning_test.go | 45 ++++++++++++++ .../providers/aws/aws_sync_workflow_test.go | 2 + 7 files changed, 185 insertions(+), 17 deletions(-) create mode 100644 internal/config/services/temporal/versioning_test.go create mode 100644 internal/models/device_registry_queue.go create mode 100644 internal/testing/temporaltest/temporaltest.go create mode 100644 internal/workflows/tasks/providers/thand/versioning.go create mode 100644 internal/workflows/tasks/providers/thand/versioning_test.go diff --git a/internal/config/services/temporal/main.go b/internal/config/services/temporal/main.go index 5dc2a38a..5287cfed 100644 --- a/internal/config/services/temporal/main.go +++ b/internal/config/services/temporal/main.go @@ -61,6 +61,40 @@ func NewTemporalClient( } } +func (a *TemporalClient) shouldUseVersioning(identity string) bool { + if a.config.DisableVersioning { + return false + } + // Keep the shared device-registry queue unversioned. Its singleton + // workflows are internal infrastructure and are reconstructed from server + // startup publication plus agent route refreshes, so the operational + // versioned deployment path is unnecessary here and has proven brittle. + return identity != models.TemporalDeviceRegistryTaskQueue +} + +func (a *TemporalClient) workerOptionsForIdentity(identity string, buildID string) worker.Options { + workerOptions := worker.Options{ + Identity: a.GetIdentity(), + MaxConcurrentActivityTaskPollers: 5, + } + + if !a.shouldUseVersioning(identity) { + return workerOptions + } + + workerOptions.DeploymentOptions = worker.DeploymentOptions{ + UseVersioning: true, + Version: worker.WorkerDeploymentVersion{ + DeploymentName: sdkConstants.TemporalDeploymentName, + BuildID: buildID, + }, + // Default workflows to Pinned behavior + DefaultVersioningBehavior: workflow.VersioningBehaviorPinned, + } + + return workerOptions +} + func (a *TemporalClient) Initialize() error { if len(a.identities) == 0 { @@ -108,27 +142,11 @@ func (a *TemporalClient) Initialize() error { // Get agent version for Worker Build ID buildID := common.GetBuildIdentifier() - - workerOptions := worker.Options{ - Identity: a.GetIdentity(), - MaxConcurrentActivityTaskPollers: 5, - } - if !a.config.DisableVersioning { logrus.WithFields(logrus.Fields{ "BuildID": buildID, "DeploymentName": sdkConstants.TemporalDeploymentName, }).Info("Configuring Worker with versioning") - - workerOptions.DeploymentOptions = worker.DeploymentOptions{ - UseVersioning: true, - Version: worker.WorkerDeploymentVersion{ - DeploymentName: sdkConstants.TemporalDeploymentName, - BuildID: buildID, - }, - // Default workflows to Pinned behavior - DefaultVersioningBehavior: workflow.VersioningBehaviorPinned, - } } // Create and start a worker for each identity (task queue) @@ -140,7 +158,15 @@ func (a *TemporalClient) Initialize() error { return nil } + hasVersionedWorkers := false for _, identity := range a.identities { + workerOptions := a.workerOptionsForIdentity(identity, buildID) + if workerOptions.DeploymentOptions.UseVersioning { + hasVersionedWorkers = true + } else if !a.config.DisableVersioning && identity == models.TemporalDeviceRegistryTaskQueue { + logrus.WithField("taskQueue", identity).Info("Starting Temporal worker without versioning for shared device registry queue") + } + newWorker := worker.New( temporalClient, identity, @@ -169,7 +195,7 @@ func (a *TemporalClient) Initialize() error { // If versioning is enabled, confirm our deployment version is registered // on the Temporal server before allowing workflow submissions via GetClient(). - if a.config.DisableVersioning { + if a.config.DisableVersioning || !hasVersionedWorkers { a.markReady() } else { go a.awaitVersionRegistration(buildID) diff --git a/internal/config/services/temporal/versioning_test.go b/internal/config/services/temporal/versioning_test.go new file mode 100644 index 00000000..e6b5741e --- /dev/null +++ b/internal/config/services/temporal/versioning_test.go @@ -0,0 +1,49 @@ +package temporal + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/thand-io/agent/internal/models" +) + +func TestShouldUseVersioningForIdentity(t *testing.T) { + t.Parallel() + + client := NewTemporalClient( + &models.TemporalConfig{ + Host: "localhost", + Port: 7233, + Namespace: "default", + DisableVersioning: false, + }, + nil, + "thand_local_alpha_server01", + models.TemporalDeviceRegistryTaskQueue, + ) + + assert.True(t, client.shouldUseVersioning("thand_local_alpha_server01")) + assert.False(t, client.shouldUseVersioning(models.TemporalDeviceRegistryTaskQueue)) +} + +func TestWorkerOptionsForIdentityKeepsRegistryQueueUnversioned(t *testing.T) { + t.Parallel() + + client := NewTemporalClient( + &models.TemporalConfig{ + Host: "localhost", + Port: 7233, + Namespace: "default", + DisableVersioning: false, + }, + nil, + "thand_local_alpha_server01", + models.TemporalDeviceRegistryTaskQueue, + ) + + operational := client.workerOptionsForIdentity("thand_local_alpha_server01", "build-123") + registry := client.workerOptionsForIdentity(models.TemporalDeviceRegistryTaskQueue, "build-123") + + assert.True(t, operational.DeploymentOptions.UseVersioning) + assert.False(t, registry.DeploymentOptions.UseVersioning) +} diff --git a/internal/models/device_registry_queue.go b/internal/models/device_registry_queue.go new file mode 100644 index 00000000..2bb0326a --- /dev/null +++ b/internal/models/device_registry_queue.go @@ -0,0 +1,3 @@ +package models + +const TemporalDeviceRegistryTaskQueue = "thand_device_registry" diff --git a/internal/testing/temporaltest/temporaltest.go b/internal/testing/temporaltest/temporaltest.go new file mode 100644 index 00000000..ffc25525 --- /dev/null +++ b/internal/testing/temporaltest/temporaltest.go @@ -0,0 +1,19 @@ +package temporaltest + +import ( + "sync" + + "go.temporal.io/sdk/worker" +) + +var binaryChecksumOnce sync.Once + +// SeedBinaryChecksum seeds the Temporal SDK's process-global checksum cache once +// per test binary so testsuite activity workers do not hash the full test +// binary the first time a workflow executes an activity. This is safe for +// t.Parallel because all callers install the same deterministic checksum value. +func SeedBinaryChecksum() { + binaryChecksumOnce.Do(func() { + worker.SetBinaryChecksum("test-build-id") + }) +} diff --git a/internal/workflows/tasks/providers/thand/versioning.go b/internal/workflows/tasks/providers/thand/versioning.go new file mode 100644 index 00000000..508d816e --- /dev/null +++ b/internal/workflows/tasks/providers/thand/versioning.go @@ -0,0 +1,24 @@ +package thand + +import ( + "strings" + + "go.temporal.io/sdk/temporal" + "go.temporal.io/sdk/workflow" +) + +// childWorkflowOptionsForTaskQueue preserves the current workflow build when a +// child is dispatched onto a different task queue. Without this, deployment- +// versioned workers can fail to pick up cross-queue child workflows. +func childWorkflowOptionsForTaskQueue( + currentQueue string, + targetQueue string, + opts workflow.ChildWorkflowOptions, +) workflow.ChildWorkflowOptions { + currentQueue = strings.TrimSpace(currentQueue) + targetQueue = strings.TrimSpace(targetQueue) + if targetQueue != "" && currentQueue != "" && targetQueue != currentQueue { + opts.VersioningIntent = temporal.VersioningIntentInheritBuildID + } + return opts +} diff --git a/internal/workflows/tasks/providers/thand/versioning_test.go b/internal/workflows/tasks/providers/thand/versioning_test.go new file mode 100644 index 00000000..65158deb --- /dev/null +++ b/internal/workflows/tasks/providers/thand/versioning_test.go @@ -0,0 +1,45 @@ +package thand + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "go.temporal.io/sdk/temporal" + "go.temporal.io/sdk/workflow" +) + +func TestChildWorkflowOptionsForTaskQueueInheritsBuildAcrossQueues(t *testing.T) { + t.Parallel() + + opts := childWorkflowOptionsForTaskQueue( + "thand_local_server_alpha", + "thand_local_workstation_alpha", + workflow.ChildWorkflowOptions{TaskQueue: "thand_local_workstation_alpha"}, + ) + + assert.Equal(t, temporal.VersioningIntentInheritBuildID, opts.VersioningIntent) +} + +func TestChildWorkflowOptionsForTaskQueueLeavesSameQueueUnspecified(t *testing.T) { + t.Parallel() + + opts := childWorkflowOptionsForTaskQueue( + "thand_local_server_alpha", + "thand_local_server_alpha", + workflow.ChildWorkflowOptions{TaskQueue: "thand_local_server_alpha"}, + ) + + assert.Equal(t, temporal.VersioningIntentUnspecified, opts.VersioningIntent) +} + +func TestChildWorkflowOptionsForTaskQueueLeavesUnknownQueueUnspecified(t *testing.T) { + t.Parallel() + + opts := childWorkflowOptionsForTaskQueue( + "", + "thand_local_workstation_alpha", + workflow.ChildWorkflowOptions{TaskQueue: "thand_local_workstation_alpha"}, + ) + + assert.Equal(t, temporal.VersioningIntentUnspecified, opts.VersioningIntent) +} diff --git a/test/functional/providers/aws/aws_sync_workflow_test.go b/test/functional/providers/aws/aws_sync_workflow_test.go index 3a17fe1a..81a9ab9c 100644 --- a/test/functional/providers/aws/aws_sync_workflow_test.go +++ b/test/functional/providers/aws/aws_sync_workflow_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thand-io/agent/internal/models" + "github.com/thand-io/agent/internal/testing/temporaltest" "go.temporal.io/sdk/activity" "go.temporal.io/sdk/testsuite" "go.temporal.io/sdk/workflow" @@ -250,6 +251,7 @@ func executeSyncWorkflow( ) *testsuite.TestWorkflowEnvironment { t.Helper() + temporaltest.SeedBinaryChecksum() suite := &testsuite.WorkflowTestSuite{} env := suite.NewTestWorkflowEnvironment() From 66846eafc78bebdbbd67cb196b7a95bd37632731 Mon Sep 17 00:00:00 2001 From: Michael Weber Date: Wed, 22 Apr 2026 18:12:47 -0500 Subject: [PATCH 06/23] feat(devices): add canonical device_id plumbing and route registry --- cmd/cli/README.md | 10 + cmd/cli/config_device_id.go | 24 ++ cmd/cli/config_device_id_test.go | 46 +++ cmd/cli/main.go | 17 + docs/api/agent/configuration.md | 16 +- docs/internal/adr-device-routing-phase-1.md | 72 +++++ docs/internal/device-model.md | 158 ++++++++++ internal/common/client.go | 15 +- internal/common/device_id_default.go | 10 + internal/common/device_id_dev.go | 25 ++ internal/common/device_id_dev_test.go | 22 ++ internal/common/device_id_test.go | 24 ++ internal/config/config.go | 36 ++- internal/config/device_bootstrap.go | 207 ++++++++++++ internal/config/device_bootstrap_test.go | 138 ++++++++ internal/config/device_definition_registry.go | 62 ++++ .../config/device_definition_registry_test.go | 91 ++++++ internal/config/device_registry.go | 298 ++++++++++++++++++ internal/config/device_registry_test.go | 139 ++++++++ internal/config/device_route_registry.go | 101 ++++++ internal/config/device_route_registry_test.go | 76 +++++ internal/config/devices.go | 111 +++++++ internal/config/devices_test.go | 117 +++++++ internal/config/model.go | 20 ++ internal/config/providers.go | 239 +++++++------- internal/config/services.go | 18 +- internal/config/services/client.go | 5 + internal/config/temporal.go | 36 ++- internal/config/temporal_activities.go | 63 ++++ internal/config/temporal_workers.go | 99 ++++++ internal/daemon/register.go | 10 +- internal/daemon/register_test.go | 149 +++++++++ internal/models/config.go | 4 + internal/models/device.go | 48 +++ internal/models/device_local_elevation.go | 75 +++++ internal/models/device_registry_queue.go | 3 - internal/workflows/manager/workflows.go | 2 +- 37 files changed, 2446 insertions(+), 140 deletions(-) create mode 100644 cmd/cli/config_device_id.go create mode 100644 cmd/cli/config_device_id_test.go create mode 100644 docs/internal/adr-device-routing-phase-1.md create mode 100644 docs/internal/device-model.md create mode 100644 internal/common/device_id_default.go create mode 100644 internal/common/device_id_dev.go create mode 100644 internal/common/device_id_dev_test.go create mode 100644 internal/common/device_id_test.go create mode 100644 internal/config/device_bootstrap.go create mode 100644 internal/config/device_bootstrap_test.go create mode 100644 internal/config/device_definition_registry.go create mode 100644 internal/config/device_definition_registry_test.go create mode 100644 internal/config/device_registry.go create mode 100644 internal/config/device_registry_test.go create mode 100644 internal/config/device_route_registry.go create mode 100644 internal/config/device_route_registry_test.go create mode 100644 internal/config/devices.go create mode 100644 internal/config/devices_test.go create mode 100644 internal/config/temporal_workers.go create mode 100644 internal/daemon/register_test.go create mode 100644 internal/models/device.go create mode 100644 internal/models/device_local_elevation.go delete mode 100644 internal/models/device_registry_queue.go diff --git a/cmd/cli/README.md b/cmd/cli/README.md index 76eef638..11176c61 100644 --- a/cmd/cli/README.md +++ b/cmd/cli/README.md @@ -118,6 +118,16 @@ Shows current configuration including: - Login endpoint - Logging level +#### `thand config device-id` +Print the canonical device ID for the current machine. + +**Usage:** +```bash +thand config device-id +``` + +This prints the effective `device_id` only, with no extra label text, so it can be copied directly into device configuration. + #### `thand roles` List available roles from the remote login server. diff --git a/cmd/cli/config_device_id.go b/cmd/cli/config_device_id.go new file mode 100644 index 00000000..708eed1e --- /dev/null +++ b/cmd/cli/config_device_id.go @@ -0,0 +1,24 @@ +package cli + +import ( + "fmt" + + "github.com/spf13/cobra" + "github.com/thand-io/agent/internal/common" +) + +var configDeviceIDCmd = &cobra.Command{ + Use: "device-id", + Short: "Print the effective device ID for this machine", + Args: cobra.NoArgs, + SilenceUsage: true, + SilenceErrors: true, + RunE: func(cmd *cobra.Command, _ []string) error { + _, err := fmt.Fprintln(cmd.OutOrStdout(), common.GetDeviceID().String()) + return err + }, +} + +func init() { + configCmd.AddCommand(configDeviceIDCmd) +} diff --git a/cmd/cli/config_device_id_test.go b/cmd/cli/config_device_id_test.go new file mode 100644 index 00000000..ac73489e --- /dev/null +++ b/cmd/cli/config_device_id_test.go @@ -0,0 +1,46 @@ +package cli + +import ( + "bytes" + "strings" + "testing" + + "github.com/spf13/cobra" + "github.com/thand-io/agent/internal/common" +) + +func TestConfigDeviceIDCommandPrintsEffectiveDeviceID(t *testing.T) { + cmd := &cobra.Command{} + var out bytes.Buffer + cmd.SetOut(&out) + + if err := configDeviceIDCmd.RunE(cmd, nil); err != nil { + t.Fatalf("RunE returned error: %v", err) + } + + got := strings.TrimSpace(out.String()) + want := common.GetDeviceID().String() + + if got != want { + t.Fatalf("printed device ID = %q, want %q", got, want) + } +} + +func TestConfigDeviceIDCommandWritesOnlyToStdout(t *testing.T) { + cmd := &cobra.Command{} + var out bytes.Buffer + var stderr bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&stderr) + + if err := configDeviceIDCmd.RunE(cmd, nil); err != nil { + t.Fatalf("RunE returned error: %v", err) + } + + if stderr.Len() != 0 { + t.Fatalf("stderr = %q, want empty", stderr.String()) + } + if strings.TrimSpace(out.String()) == "" { + t.Fatal("stdout was empty") + } +} diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 16f17934..40ad8053 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -1,6 +1,7 @@ package cli import ( + "context" "errors" "fmt" "os" @@ -120,6 +121,17 @@ func preRunConfigE(cmd *cobra.Command, mode config.Mode) error { case config.ModeAgent: + // Materialize provider, role, and workflow definitions from the current + // configured sources (path/url/vault/defaults) before provider + // initialization. The agent bootstrap path currently only returns + // services/device state from the login server, so InitializeProviders + // still depends on this local definition load. + err = cfg.ReloadConfig() + if err != nil { + logrus.WithError(err).Errorln("Failed to load local agent configuration") + return fmt.Errorf("failed to load local agent configuration: %w", err) + } + // Initialize providers err = cfg.InitializeProviders() @@ -128,6 +140,11 @@ func preRunConfigE(cmd *cobra.Command, mode config.Mode) error { return err } + isDefaultLoginEndpoint := cfg.GetLoginServerUrl() == common.DefaultLoginServerEndpoint + if cfg.HasLoginServer() && !isDefaultLoginEndpoint { + go cfg.RunDeviceBootstrap(context.Background()) + } + case config.ModeServer: // Load local config first before registering with the thand server. diff --git a/docs/api/agent/configuration.md b/docs/api/agent/configuration.md index 323aa604..fdacfa92 100644 --- a/docs/api/agent/configuration.md +++ b/docs/api/agent/configuration.md @@ -56,7 +56,7 @@ Currently a stub endpoint for future pre-flight validation. ## Register Agent -Register an agent with the server. +Bootstrap agent configuration from the server. **POST** `/register` @@ -64,13 +64,18 @@ Register an agent with the server. - Server Mode Only +`/register` is a configuration/bootstrap handshake only. It returns server-managed config snapshots, but it does not publish live device routes. Running agents publish their live route directly to Temporal after bootstrap succeeds. + ### Request Body ```json { + "mode": "agent", + "identifier": "11111111-2222-3333-4444-555555555555", "environment": { - "name": "production", - "description": "Production environment configuration" + "name": "workstation-alpha", + "hostname": "workstation-alpha.example.test", + "platform": "local" } } ``` @@ -129,6 +134,11 @@ The registration response contains the complete configuration for the agent, inc If the upstream server has a newer version of the configuration, the agent will update its local configuration to match the server's state. This ensures that policies and configurations are consistent across the infrastructure. +For device-local workflows: + +- servers publish device definitions/policy to the shared device-definition registry +- agents publish live `device_id -> task_queue` route state to the shared device-route registry + ## Post-flight Check Validate configuration after registration. diff --git a/docs/internal/adr-device-routing-phase-1.md b/docs/internal/adr-device-routing-phase-1.md new file mode 100644 index 00000000..0036b705 --- /dev/null +++ b/docs/internal/adr-device-routing-phase-1.md @@ -0,0 +1,72 @@ +--- +layout: default +title: ADR: Device Routing Phase 1 +parent: Internal +nav_order: 2 +--- + +# ADR: Device Routing Phase 1 + +## Status + +Accepted. + +## Context + +Device-local workflows need a way to target a specific machine. + +At the same time, the project is not yet ready to introduce full cryptographic device identity or a dedicated device control plane. + +## Decision + +Phase 1 uses: + +- first-class device definitions on the server +- canonical `device_id` as the device-matching key +- `/register` as bootstrap/config sync only +- live route tracking from agent-published device-route state +- shared Temporal device-definition and device-route registries on `thand_device_registry` +- periodic route refresh from running agents +- fresh-route checks for device-targeted workflow dispatch + +It explicitly does not use: + +- provider tenants as device identifiers +- indefinite waiting for device-targeted authorize steps + +## Rationale + +This choice gives us the smallest useful device substrate that: + +- removes stale static routing +- supports reconnect-aware device execution +- keeps the design generic for future device-local workflows +- leaves room for a later secure identity redesign + +## Alternatives Considered + +### Model devices as provider tenants + +Rejected because tenants are provider-scoped account concepts, not machine execution concepts. + +### Require immediate online presence with no waiting + +Rejected because short outages and restarts are normal. A bounded wait is a better operator experience for authorize, while revoke can retry for reconciliation. + +### Solve cryptographic device identity first + +Rejected for phase 1 because it would block useful device plumbing behind a larger security redesign. + +## Consequences + +Positive consequences: + +- simpler config +- more honest routing model +- reusable device-targeted execution layer + +Negative consequences: + +- the current generated `device_id` is still a client-presented identity and not yet a cryptographically enrolled device credential +- device config and routing now depend on internal shared Temporal registries that need future hardening and operational polish +- later phases will need migration work for stronger enrollment diff --git a/docs/internal/device-model.md b/docs/internal/device-model.md new file mode 100644 index 00000000..ac4e823b --- /dev/null +++ b/docs/internal/device-model.md @@ -0,0 +1,158 @@ +--- +layout: default +title: Device Model +parent: Internal +nav_order: 1 +--- + +# Device Model + +This document describes the long-term device architecture for thand, the first phase now implemented in code, and the work intentionally deferred. + +## Why Devices Exist + +Local execution targets such as laptops, desktops, and servers have different lifecycle and security requirements than: + +- users, which represent human principals +- providers, which represent integrations such as AWS or JumpCloud +- tenants, which represent provider-scoped accounts or resource containers + +A device is therefore modeled as a first-class server-managed object. Devices let the server answer questions like: + +- which machine should receive a device-local workflow +- which device-local policy applies to that machine +- whether the machine is currently connected and routable + +Operators can print the local machine's current device ID with `thand config device-id`. +Non-production builds may override the generated value for deterministic testing, but production binaries always use the generated machine-derived `device_id`. + +## What A Device Is + +In the current model, a device has: + +- a stable device ID +- human-readable metadata such as `name` and `description` + +Runtime connection state is tracked separately from static device policy. That runtime state currently includes: + +- `task_queue` +- `last_seen_at` +- derived freshness / connected status + +## Why Devices Are Not Tenants + +Provider tenants and devices look superficially similar because both can affect routing and authorization scope, but they solve different problems. + +Tenants are provider-scoped. A tenant says which account or org inside a provider a request applies to. Devices are execution-scoped. A device says which machine should run a workflow or local action. + +Using tenants for devices would overload provider semantics with machine lifecycle concerns such as: + +- agent registration +- live route freshness +- local reconciliation after reconnect +- future privileged-helper transport + +That coupling would make both models harder to reason about, so devices remain separate. + +## Target Architecture + +The intended architecture is: + +1. The server owns device definitions and device policy. +2. An agent represents one device, running as a system-level service rather than a per-user helper. +3. `/register` bootstraps config only; running agents publish live route state directly to Temporal. +4. Device-targeted workflows route through that live route only. + +Today the canonical `device_id` is machine-derived. Longer term, device registration should use a stronger enrolled identity, but keep the same `device_id` abstraction boundary. + +## Phase 1: What Is Implemented Now + +Phase 1 establishes the basic device substrate without yet solving strong device identity. + +Implemented now: + +- first-class `Device` definitions in config +- live device connection state tracked in the shared Temporal device-route registry +- shared device definitions tracked in a Temporal device-definition registry +- periodic device registration refresh from the agent +- route freshness checks using `last_seen_at` +- device-targeted provider child workflows using a fresh live route +- bounded waiting for authorize when a device is temporarily offline +- retrying revoke reconciliation when the device is offline + +Not yet implemented: + +- cryptographic device enrollment or proof-of-possession identity behind the existing `device_id` abstraction +- a dedicated control-plane service for device config +- device discovery UI or richer device selection UX +- a privileged local helper split from the main agent + +Phase 1 routes only from fresh live registration state. + +## Current Routing Model + +Today, routing works like this: + +1. An agent registers with the server. +2. The server returns bootstrap/config data to the agent. +3. The agent publishes its current `task_queue` and `last_seen_at` to the shared Temporal device-route registry on `thand_device_registry`. +4. Servers publish configured device definitions to the shared Temporal device-definition registry on `thand_device_registry`. +5. Device-targeted workflows query shared device policy during execution planning and ask for a fresh route before dispatch. +6. If the route is missing or stale, authorize waits for a bounded window and revoke keeps retrying for reconciliation. + +This gives us a cleaner failure model: + +- authorize should not succeed much later than requested +- revoke should converge once the device reconnects +- timed local enforcement should not depend solely on centralized connectivity + +## Consequences of the Current Design + +The current phase-1 design has a few important consequences: + +- devices are now a generic execution substrate, not a sudo-only feature +- routing depends on liveness, not static config +- agents are treated as per-device services, not per-user services +- the machine-derived `device_id` is now the single routing identity used across registration, planning, and dispatch + +## Known Shortcomings + +The biggest gap is still device identity hardening. + +Today the server matches a connecting agent to a device through the generated `device_id`. That is enough for phase 1 plumbing and local development, but it is not strong enough for a final design because it is still based on client-presented identity. + +Other gaps: + +- no dedicated control-plane API for device configuration +- no secure enrollment story yet +- shared device registries are still internal Temporal workflows rather than a broader device control-plane service +- no independent privileged helper transport yet on Linux or Windows +- no explicit multi-agent-per-device design, because the current assumption is one system agent per device + +## Future Phases + +Future work should cover at least: + +### Strong device identity + +- enrolled device credentials +- challenge / proof-of-possession registration +- authenticated binding between device record and live route + +### Dedicated device config distribution + +- server-managed per-device policy delivery +- explicit control-plane lifecycle for devices +- eventual separation between interactive user login and device bootstrap + +### Privileged local helper + +- OS-native trust checks between the unprivileged agent, broker, and notifier +- narrow local lease/enforcer contract with persisted expiry and restart reconciliation +- future Linux and Windows helpers that match the same broker client abstraction + +### Better UX + +- device discovery APIs +- device picker UX +- clearer offline / reconnect status in local-device workflows diff --git a/internal/common/client.go b/internal/common/client.go index 172fe699..61750d67 100644 --- a/internal/common/client.go +++ b/internal/common/client.go @@ -7,19 +7,22 @@ import ( "github.com/google/uuid" ) -// GetClientIdentifier returns a UUID that uniquely identifies this system. -// It uses the machine's hardware ID to generate a consistent, system-specific UUID. -func GetClientIdentifier() uuid.UUID { - +func getMachineDerivedDeviceID() uuid.UUID { // TODO(hugh): Check if the thand.io config exists and use that for an identifier. id, err := machineid.ID() if err != nil { - // Fallback to a random ephemeral UUID if machine ID cannot be obtained + // Fallback to a random ephemeral UUID if machine ID cannot be obtained. return uuid.New() } - // Hash the machine ID and convert to UUID format + // Hash the machine ID and convert to UUID format. hash := sha256.Sum256([]byte(id)) return uuid.UUID(hash[:16]) } + +// GetClientIdentifier returns the stable machine-derived identifier used by +// legacy call sites. Device registration and routing should prefer GetDeviceID. +func GetClientIdentifier() uuid.UUID { + return GetDeviceID() +} diff --git a/internal/common/device_id_default.go b/internal/common/device_id_default.go new file mode 100644 index 00000000..b6858c3b --- /dev/null +++ b/internal/common/device_id_default.go @@ -0,0 +1,10 @@ +//go:build !thand_dev + +package common + +import "github.com/google/uuid" + +// GetDeviceID returns the effective device identity for this machine. +func GetDeviceID() uuid.UUID { + return getMachineDerivedDeviceID() +} diff --git a/internal/common/device_id_dev.go b/internal/common/device_id_dev.go new file mode 100644 index 00000000..ea008b90 --- /dev/null +++ b/internal/common/device_id_dev.go @@ -0,0 +1,25 @@ +//go:build thand_dev + +package common + +import ( + "os" + "strings" + + "github.com/google/uuid" +) + +const deviceIDOverrideEnvVar = "THAND_DEV_DEVICE_ID_OVERRIDE" + +// GetDeviceID returns the effective device identity for this machine. +// Dev-tagged builds may override the machine-derived ID for deterministic tests. +func GetDeviceID() uuid.UUID { + override := strings.TrimSpace(os.Getenv(deviceIDOverrideEnvVar)) + if override != "" { + if parsed, err := uuid.Parse(override); err == nil { + return parsed + } + } + + return getMachineDerivedDeviceID() +} diff --git a/internal/common/device_id_dev_test.go b/internal/common/device_id_dev_test.go new file mode 100644 index 00000000..ce513620 --- /dev/null +++ b/internal/common/device_id_dev_test.go @@ -0,0 +1,22 @@ +//go:build thand_dev + +package common + +import "testing" + +func TestGetDeviceIDHonorsDevOverride(t *testing.T) { + t.Setenv(deviceIDOverrideEnvVar, "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") + + got := GetDeviceID().String() + want := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + if got != want { + t.Fatalf("GetDeviceID() = %q, want %q", got, want) + } +} + +func TestGetClientIdentifierMatchesDeviceID(t *testing.T) { + if got, want := GetClientIdentifier(), GetDeviceID(); got != want { + t.Fatalf("GetClientIdentifier() = %q, want %q", got, want) + } +} diff --git a/internal/common/device_id_test.go b/internal/common/device_id_test.go new file mode 100644 index 00000000..775fe08f --- /dev/null +++ b/internal/common/device_id_test.go @@ -0,0 +1,24 @@ +//go:build !thand_dev + +package common + +import ( + "testing" +) + +func TestGetDeviceIDIgnoresDevOverrideInProductionBuild(t *testing.T) { + t.Setenv("THAND_DEV_DEVICE_ID_OVERRIDE", "11111111-2222-3333-4444-555555555555") + + got := GetDeviceID() + want := getMachineDerivedDeviceID() + + if got != want { + t.Fatalf("GetDeviceID() = %q, want machine-derived %q", got, want) + } +} + +func TestGetClientIdentifierMatchesDeviceID(t *testing.T) { + if got, want := GetClientIdentifier(), GetDeviceID(); got != want { + t.Fatalf("GetClientIdentifier() = %q, want %q", got, want) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index dd3282d6..e994d922 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -603,7 +603,9 @@ func (c *Config) RegisterWithThandServer() error { }, } - registration, err := c.syncWithEndpoint(thandLoginUrl, authentication) + registration, err := c.syncWithEndpoint(thandLoginUrl, authentication, loginServerRegistrationOptions{ + applyServices: true, + }) if err != nil { return fmt.Errorf("failed to register with thand server: %w", err) @@ -623,17 +625,39 @@ func (c *Config) RegisterWithThandServer() error { } +type loginServerRegistrationOptions struct { + applyServices bool +} + func (c *Config) RegisterWithLoginServer(auth *model.ReferenceableAuthenticationPolicy) (*RegistrationResponse, error) { loginUrl := c.DiscoverLoginServerApiUrl( c.GetLoginServerUrl(), ) - return c.syncWithEndpoint(loginUrl, auth) + return c.syncWithEndpoint(loginUrl, auth, loginServerRegistrationOptions{ + applyServices: true, + }) + +} + +func (c *Config) RefreshLoginServerRegistration(auth *model.ReferenceableAuthenticationPolicy) (*RegistrationResponse, error) { + + loginUrl := c.DiscoverLoginServerApiUrl( + c.GetLoginServerUrl(), + ) + + return c.syncWithEndpoint(loginUrl, auth, loginServerRegistrationOptions{ + applyServices: false, + }) } -func (c *Config) syncWithEndpoint(loginUrl string, authentication *model.ReferenceableAuthenticationPolicy) (*RegistrationResponse, error) { +func (c *Config) syncWithEndpoint( + loginUrl string, + authentication *model.ReferenceableAuthenticationPolicy, + options loginServerRegistrationOptions, +) (*RegistrationResponse, error) { version, commit, _ := common.GetModuleBuildInfo() @@ -641,7 +665,7 @@ func (c *Config) syncWithEndpoint(loginUrl string, authentication *model.Referen Mode: c.GetMode(), Version: version, Commit: commit, - Identifier: common.GetClientIdentifier(), + Identifier: common.GetDeviceID(), Endpoint: c.GetLoginServerUrl(), Origin: c.GetLocalServerUrl(), }) @@ -678,7 +702,7 @@ func (c *Config) syncWithEndpoint(loginUrl string, authentication *model.Referen Environment: &c.Environment, Version: version, Commit: commit, - Identifier: common.GetClientIdentifier(), + Identifier: common.GetDeviceID(), Endpoint: c.GetLoginServerUrl(), Origin: c.GetLocalServerUrl(), }) @@ -757,7 +781,7 @@ func (c *Config) syncWithEndpoint(loginUrl string, authentication *model.Referen } } - if registrationResponse.Services != nil { + if options.applyServices && registrationResponse.Services != nil { // Setup temporal services if provided if registrationResponse.Services.Temporal != nil { diff --git a/internal/config/device_bootstrap.go b/internal/config/device_bootstrap.go new file mode 100644 index 00000000..7056ab75 --- /dev/null +++ b/internal/config/device_bootstrap.go @@ -0,0 +1,207 @@ +package config + +import ( + "context" + "fmt" + "time" + + "github.com/sirupsen/logrus" + "github.com/thand-io/agent/internal/common" + "github.com/thand-io/agent/internal/models" +) + +const ( + deviceBootstrapInitialBackoff = 2 * time.Second + deviceBootstrapMaxBackoff = 1 * time.Minute +) + +func (c *Config) BootstrapDeviceWithLoginServer() error { + if !c.IsAgent() { + return fmt.Errorf("device bootstrap is only valid in agent mode") + } + if !c.HasLoginServer() { + return fmt.Errorf("no login server endpoint configured") + } + + registration, err := c.RegisterWithLoginServer(nil) + if err != nil { + return err + } + + if err := c.applyRegistrationConfiguration(registration); err != nil { + return err + } + + environment := c.GetEnvironmentConfig() + logrus.WithFields(logrus.Fields{ + "device_id": common.GetDeviceID().String(), + "name": environment.Name, + "hostname": environment.Hostname, + "platform": environment.Platform, + "has_config": registration != nil, + }).Info("Bootstrapped agent configuration from login server") + + if err := c.EnsureProviderTemporalBindings(); err != nil { + return fmt.Errorf("ensuring provider temporal bindings: %w", err) + } + + if err := c.PublishCurrentAgentRoute(context.Background()); err != nil { + return fmt.Errorf("publishing current device route: %w", err) + } + + return nil +} + +func (c *Config) RefreshDeviceRegistrationWithLoginServer() error { + if !c.IsAgent() { + return fmt.Errorf("device refresh is only valid in agent mode") + } + if !c.HasLoginServer() { + return fmt.Errorf("no login server endpoint configured") + } + + registration, err := c.RefreshLoginServerRegistration(nil) + if err != nil { + return err + } + + if err := c.applyRegistrationConfiguration(registration); err != nil { + return err + } + + environment := c.GetEnvironmentConfig() + logrus.WithFields(logrus.Fields{ + "device_id": common.GetDeviceID().String(), + "name": environment.Name, + "hostname": environment.Hostname, + "platform": environment.Platform, + }).Debug("Refreshed device registration with login server") + + if err := c.PublishCurrentAgentRoute(context.Background()); err != nil { + return fmt.Errorf("publishing current device route: %w", err) + } + + return nil +} + +func (c *Config) RunDeviceBootstrap(ctx context.Context) { + backoff := deviceBootstrapInitialBackoff + bootstrapped := false + + for { + var err error + if bootstrapped { + err = c.RefreshDeviceRegistrationWithLoginServer() + } else { + err = c.BootstrapDeviceWithLoginServer() + } + if err == nil { + bootstrapped = true + backoff = deviceBootstrapInitialBackoff + + timer := time.NewTimer(deviceRouteRefreshInterval) + select { + case <-ctx.Done(): + timer.Stop() + return + case <-timer.C: + } + continue + } + + logrus.WithError(err).WithField("retry_in", backoff).Warn("device bootstrap failed; retrying") + + timer := time.NewTimer(backoff) + select { + case <-ctx.Done(): + timer.Stop() + return + case <-timer.C: + } + + backoff *= 2 + if backoff > deviceBootstrapMaxBackoff { + backoff = deviceBootstrapMaxBackoff + } + } +} + +func (c *Config) applyRegistrationConfiguration(registration *RegistrationResponse) error { + if registration == nil { + return nil + } + + if registration.Roles == nil && + registration.Workflows == nil && + registration.Providers == nil { + return nil + } + + beforeGeneration := c.getConfigGeneration() + if err := c.MergeConfiguration(registration); err != nil { + return fmt.Errorf("merging registration configuration: %w", err) + } + + if c.getConfigGeneration() == beforeGeneration { + return nil + } + + if err := c.InitializeProviders(); err != nil { + return fmt.Errorf("initializing providers from registration configuration: %w", err) + } + + if !c.IsClient() { + go func() { + if err := c.ReloadRoleIndexes(); err != nil { + logrus.WithError(err).Errorln("Failed to reload role indexes after registration configuration update") + } + }() + } + + return nil +} + +func (c *Config) getConfigGeneration() uint64 { + c.mu.RLock() + defer c.mu.RUnlock() + return c.configGeneration +} + +func (c *Config) PublishCurrentAgentRoute(ctx context.Context) error { + services := c.GetServices() + if services == nil || !services.HasTemporal() { + logrus.Debug("Skipping current device route publication because Temporal is unavailable") + return nil + } + + temporalService := services.GetTemporal() + if temporalService == nil || !temporalService.HasClient() { + logrus.Debug("Skipping current device route publication because the Temporal client is unavailable") + return nil + } + + return c.publishCurrentAgentRoute(ctx, c.PublishDeviceConnectionState) +} + +func (c *Config) publishCurrentAgentRoute( + ctx context.Context, + publish func(context.Context, models.DeviceConnectionState) error, +) error { + if !c.IsAgent() { + return fmt.Errorf("current device route publication is only valid in agent mode") + } + if publish == nil { + return fmt.Errorf("device route publisher is required") + } + + environment := c.GetEnvironmentConfig() + state := models.DeviceConnectionState{ + DeviceID: common.GetDeviceID().String(), + TaskQueue: environment.GetIdentifier(), + Name: c.GetEnvironment().Name, + Hostname: c.GetEnvironment().Hostname, + Platform: string(c.GetEnvironment().Platform), + } + + return publish(ctx, state) +} diff --git a/internal/config/device_bootstrap_test.go b/internal/config/device_bootstrap_test.go new file mode 100644 index 00000000..8c1b332e --- /dev/null +++ b/internal/config/device_bootstrap_test.go @@ -0,0 +1,138 @@ +package config + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/serverlessworkflow/sdk-go/v3/model" + "github.com/thand-io/agent/internal/common" + "github.com/thand-io/agent/internal/models" +) + +func TestBootstrapDeviceWithLoginServerMergesRemoteProviderDefinitions(t *testing.T) { + t.Parallel() + + registerResponse := RegistrationResponse{ + Success: true, + Providers: &ProviderDefinitionsConfig{ + Definitions: map[string]models.ProviderConfig{ + "oauth2-directory": { + Name: "Directory Login", + Description: "Remote OAuth2 provider", + Provider: "oauth2", + Enabled: true, + Config: &models.BasicConfig{ + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "auth_url": "https://auth.example.com/oauth2/auth", + "token_url": "https://auth.example.com/oauth2/token", + "redirect_url": "http://localhost/callback", + }, + }, + }, + }, + } + + loginServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/preflight": + var req PreflightRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decoding preflight request: %v", err) + } + if got, want := req.Identifier.String(), common.GetDeviceID().String(); got != want { + t.Fatalf("preflight identifier = %q, want %q", got, want) + } + w.WriteHeader(http.StatusOK) + case "/api/v1/register": + var req RegistrationRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decoding registration request: %v", err) + } + if got, want := req.Identifier.String(), common.GetDeviceID().String(); got != want { + t.Fatalf("registration identifier = %q, want %q", got, want) + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(registerResponse); err != nil { + t.Fatalf("encoding registration response: %v", err) + } + case "/api/v1/sync": + w.WriteHeader(http.StatusOK) + default: + http.NotFound(w, r) + } + })) + defer loginServer.Close() + + cfg := &Config{ + mode: ModeAgent, + Login: models.LoginConfig{ + Endpoint: &model.Endpoint{ + EndpointConfig: &model.EndpointConfiguration{ + URI: &model.LiteralUri{Value: loginServer.URL}, + }, + }, + }, + Thand: models.ThandConfig{ + Endpoint: loginServer.URL, + }, + Providers: ProviderDefinitionsConfig{ + Definitions: map[string]models.ProviderConfig{ + "local-elevation": { + Name: "Local Elevation", + Description: "Local privilege elevation provider", + Provider: "local", + Enabled: true, + }, + }, + }, + } + + if err := cfg.InitializeProviders(); err != nil { + t.Fatalf("InitializeProviders() error = %v", err) + } + + if cfg.HasProvider("oauth2-directory") { + t.Fatal("oauth2-directory provider already present before bootstrap") + } + + if err := cfg.BootstrapDeviceWithLoginServer(); err != nil { + t.Fatalf("BootstrapDeviceWithLoginServer() error = %v", err) + } + + if !cfg.HasProvider("oauth2-directory") { + t.Fatal("oauth2-directory provider missing after bootstrap") + } +} + +func TestPublishCurrentAgentRouteUsesCanonicalDeviceIdentity(t *testing.T) { + t.Parallel() + + cfg := &Config{ + mode: ModeAgent, + Environment: models.EnvironmentConfig{ + Name: "workstation-alpha", + Hostname: "workstation-alpha.example.test", + Platform: models.Local, + }, + } + + var published models.DeviceConnectionState + err := cfg.publishCurrentAgentRoute(context.Background(), func(ctx context.Context, state models.DeviceConnectionState) error { + published = state + return nil + }) + if err != nil { + t.Fatalf("publishCurrentAgentRoute() error = %v", err) + } + + if got, want := published.DeviceID, common.GetDeviceID().String(); got != want { + t.Fatalf("DeviceID = %q, want %q", got, want) + } + if got, want := published.TaskQueue, "thand_local_workstation_alpha"; got != want { + t.Fatalf("TaskQueue = %q, want %q", got, want) + } +} diff --git a/internal/config/device_definition_registry.go b/internal/config/device_definition_registry.go new file mode 100644 index 00000000..f2e0532b --- /dev/null +++ b/internal/config/device_definition_registry.go @@ -0,0 +1,62 @@ +package config + +import ( + "fmt" + "strings" + + "github.com/thand-io/agent/internal/models" + "go.temporal.io/sdk/workflow" +) + +func deviceDefinitionRegistryWorkflow(ctx workflow.Context) error { + definitions := map[string]models.Device{} + + if err := workflow.SetQueryHandler(ctx, models.TemporalGetDeviceDefinitionQueryName, func(deviceID string) (*models.Device, error) { + deviceID = strings.TrimSpace(deviceID) + if deviceID == "" { + return nil, fmt.Errorf("device id is required") + } + + device, ok := definitions[deviceID] + if !ok { + return nil, fmt.Errorf("device %q is not configured", deviceID) + } + + deviceCopy := device + return &deviceCopy, nil + }); err != nil { + return err + } + + signalCh := workflow.GetSignalChannel(ctx, models.TemporalDeviceDefinitionUpsertSignalName) + for { + cancelled := false + selector := workflow.NewSelector(ctx) + selector.AddReceive(signalCh, func(c workflow.ReceiveChannel, _ bool) { + var device models.Device + c.Receive(ctx, &device) + + device = normalizeDeviceDefinition(device) + if device.ID == "" { + return + } + + existing, exists := definitions[device.ID] + if exists && !deviceDefinitionsEqual(existing, device) { + workflow.GetLogger(ctx).Warn("Ignoring conflicting device definition update", + "device_id", device.ID, + ) + return + } + + definitions[device.ID] = device + }) + selector.AddReceive(ctx.Done(), func(workflow.ReceiveChannel, bool) { + cancelled = true + }) + selector.Select(ctx) + if cancelled { + return ctx.Err() + } + } +} diff --git a/internal/config/device_definition_registry_test.go b/internal/config/device_definition_registry_test.go new file mode 100644 index 00000000..cd85a998 --- /dev/null +++ b/internal/config/device_definition_registry_test.go @@ -0,0 +1,91 @@ +package config + +import ( + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thand-io/agent/internal/models" + "go.temporal.io/sdk/testsuite" +) + +func queryDeviceDefinitionEventually( + t *testing.T, + env *testsuite.TestWorkflowEnvironment, + deviceID string, + assertDevice func(models.Device), +) { + t.Helper() + + var poll func() + poll = func() { + value, err := env.QueryWorkflow(models.TemporalGetDeviceDefinitionQueryName, deviceID) + if err != nil && strings.Contains(err.Error(), "unknown queryType") { + env.RegisterDelayedCallback(poll, time.Millisecond) + return + } + require.NoError(t, err) + + var device models.Device + require.NoError(t, value.Get(&device)) + assertDevice(device) + + env.CancelWorkflow() + } + + env.RegisterDelayedCallback(poll, time.Millisecond) +} + +func TestDeviceDefinitionRegistryWorkflowReturnsConfiguredDevice(t *testing.T) { + t.Parallel() + + var suite testsuite.WorkflowTestSuite + env := suite.NewTestWorkflowEnvironment() + + env.RegisterDelayedCallback(func() { + env.SignalWorkflow(models.TemporalDeviceDefinitionUpsertSignalName, models.Device{ + ID: "device-alpha", + Name: "Device Alpha", + Enabled: true, + }) + }, 0) + + queryDeviceDefinitionEventually(t, env, "device-alpha", func(device models.Device) { + assert.Equal(t, "device-alpha", device.ID) + assert.Equal(t, "Device Alpha", device.Name) + }) + + env.ExecuteWorkflow(deviceDefinitionRegistryWorkflow) + require.True(t, env.IsWorkflowCompleted()) + require.Error(t, env.GetWorkflowError()) +} + +func TestDeviceDefinitionRegistryWorkflowRejectsConflictingUpdates(t *testing.T) { + t.Parallel() + + var suite testsuite.WorkflowTestSuite + env := suite.NewTestWorkflowEnvironment() + + env.RegisterDelayedCallback(func() { + env.SignalWorkflow(models.TemporalDeviceDefinitionUpsertSignalName, models.Device{ + ID: "device-alpha", + Name: "Device Alpha", + Enabled: true, + }) + env.SignalWorkflow(models.TemporalDeviceDefinitionUpsertSignalName, models.Device{ + ID: "device-alpha", + Name: "Conflicting Device Alpha", + Enabled: true, + }) + }, 0) + + queryDeviceDefinitionEventually(t, env, "device-alpha", func(device models.Device) { + assert.Equal(t, "Device Alpha", device.Name) + }) + + env.ExecuteWorkflow(deviceDefinitionRegistryWorkflow) + require.True(t, env.IsWorkflowCompleted()) + require.Error(t, env.GetWorkflowError()) +} diff --git a/internal/config/device_registry.go b/internal/config/device_registry.go new file mode 100644 index 00000000..2b31f1ae --- /dev/null +++ b/internal/config/device_registry.go @@ -0,0 +1,298 @@ +package config + +import ( + "context" + "errors" + "fmt" + "reflect" + "slices" + "strings" + "time" + + "github.com/sirupsen/logrus" + "github.com/thand-io/agent/internal/models" + "go.temporal.io/api/enums/v1" + "go.temporal.io/api/serviceerror" + workflowservice "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/sdk/client" +) + +const deviceRegistryQueryTimeout = 30 * time.Second + +type deviceRegistryTemporalClient interface { + DescribeWorkflowExecution(ctx context.Context, workflowID, runID string) (*workflowservice.DescribeWorkflowExecutionResponse, error) + QueryWorkflowWithOptions(ctx context.Context, request *client.QueryWorkflowWithOptionsRequest) (*client.QueryWorkflowWithOptionsResponse, error) + SignalWithStartWorkflow( + ctx context.Context, + workflowID string, + signalName string, + signalArg interface{}, + options client.StartWorkflowOptions, + workflow interface{}, + args ...interface{}, + ) (client.WorkflowRun, error) + TerminateWorkflow(ctx context.Context, workflowID, runID, reason string, details ...interface{}) error +} + +func deviceRegistryStartWorkflowOptions(workflowID string) client.StartWorkflowOptions { + // These internal singleton workflows always run on the shared + // device-registry queue, which is intentionally unversioned even when the + // operational server/agent queues use worker deployments. + return client.StartWorkflowOptions{ + ID: workflowID, + TaskQueue: models.TemporalDeviceRegistryTaskQueue, + } +} + +func registryWorkflowUsesVersioning(description *workflowservice.DescribeWorkflowExecutionResponse) bool { + if description == nil { + return false + } + + info := description.GetWorkflowExecutionInfo() + if info == nil { + return false + } + + if strings.TrimSpace(info.GetAssignedBuildId()) != "" || strings.TrimSpace(info.GetInheritedBuildId()) != "" { + return true + } + + versioningInfo := info.GetVersioningInfo() + if versioningInfo == nil { + return false + } + + if versioningInfo.GetBehavior() != enums.VERSIONING_BEHAVIOR_UNSPECIFIED { + return true + } + + return versioningInfo.GetVersioningOverride() != nil +} + +func normalizeDeviceDefinition(device models.Device) models.Device { + device.ID = strings.TrimSpace(device.ID) + device.Name = strings.TrimSpace(device.Name) + device.Description = strings.TrimSpace(device.Description) + device.Platform = strings.TrimSpace(device.Platform) + if device.LocalElevation != nil { + policy := *device.LocalElevation + policy.AllowedModes = slices.Clone(policy.AllowedModes) + policy.DeniedUsernames = slices.Clone(policy.DeniedUsernames) + policy.AllowedUIDRanges = slices.Clone(policy.AllowedUIDRanges) + if len(policy.Accounts) > 0 { + policy.Accounts = append([]models.DeviceLocalElevationAccount(nil), policy.Accounts...) + for i := range policy.Accounts { + policy.Accounts[i].Identity = strings.TrimSpace(policy.Accounts[i].Identity) + policy.Accounts[i].Email = strings.TrimSpace(policy.Accounts[i].Email) + policy.Accounts[i].Username = strings.TrimSpace(policy.Accounts[i].Username) + policy.Accounts[i].LocalUsername = strings.TrimSpace(policy.Accounts[i].LocalUsername) + } + } + for i := range policy.AllowedModes { + policy.AllowedModes[i] = strings.TrimSpace(policy.AllowedModes[i]) + } + for i := range policy.DeniedUsernames { + policy.DeniedUsernames[i] = strings.TrimSpace(policy.DeniedUsernames[i]) + } + for i := range policy.AllowedUIDRanges { + policy.AllowedUIDRanges[i] = strings.TrimSpace(policy.AllowedUIDRanges[i]) + } + device.LocalElevation = &policy + } + return device +} + +func deviceDefinitionsEqual(left, right models.Device) bool { + return reflect.DeepEqual(normalizeDeviceDefinition(left), normalizeDeviceDefinition(right)) +} + +func queryDeviceDefinition( + ctx context.Context, + temporalClient deviceRegistryTemporalClient, + deviceID string, +) (*models.Device, error) { + deviceID = strings.TrimSpace(deviceID) + if deviceID == "" { + return nil, fmt.Errorf("device id is required") + } + if temporalClient == nil { + return nil, fmt.Errorf("shared device registry is unavailable") + } + + timeoutCtx, cancel := context.WithTimeout(ctx, deviceRegistryQueryTimeout) + defer cancel() + + queryResponse, err := temporalClient.QueryWorkflowWithOptions(timeoutCtx, &client.QueryWorkflowWithOptionsRequest{ + WorkflowID: models.TemporalDeviceDefinitionRegistryWorkflowID, + RunID: "", + QueryType: models.TemporalGetDeviceDefinitionQueryName, + QueryRejectCondition: enums.QUERY_REJECT_CONDITION_NOT_OPEN, + Args: []any{deviceID}, + }) + if err != nil { + return nil, fmt.Errorf("device %q is not configured", deviceID) + } + if queryResponse == nil || queryResponse.QueryResult == nil { + return nil, fmt.Errorf("device %q is not configured", deviceID) + } + + var device models.Device + if err := queryResponse.QueryResult.Get(&device); err != nil { + return nil, err + } + + normalized := normalizeDeviceDefinition(device) + return &normalized, nil +} + +func ensureRegistryWorkflowTaskQueue( + ctx context.Context, + temporalClient deviceRegistryTemporalClient, + workflowID string, +) error { + if temporalClient == nil { + return fmt.Errorf("temporal client is required to manage registry workflow %q", workflowID) + } + + description, err := temporalClient.DescribeWorkflowExecution(ctx, workflowID, "") + if err != nil { + var notFound *serviceerror.NotFound + if errors.As(err, ¬Found) { + return nil + } + return err + } + + taskQueue := "" + if description != nil && description.ExecutionConfig != nil && description.ExecutionConfig.TaskQueue != nil { + taskQueue = strings.TrimSpace(description.ExecutionConfig.TaskQueue.Name) + } + if taskQueue == "" || taskQueue == models.TemporalDeviceRegistryTaskQueue { + if !registryWorkflowUsesVersioning(description) { + return nil + } + + logrus.WithFields(logrus.Fields{ + "workflow_id": workflowID, + "task_queue": models.TemporalDeviceRegistryTaskQueue, + }).Warn("Recreating versioned device registry workflow on the canonical unversioned device registry queue") + + return temporalClient.TerminateWorkflow(ctx, workflowID, "", "migrating device registry workflow to canonical unversioned queue") + } + + logrus.WithFields(logrus.Fields{ + "workflow_id": workflowID, + "task_queue": taskQueue, + "expected": models.TemporalDeviceRegistryTaskQueue, + }).Warn("Recreating device registry workflow on the canonical device registry queue") + + return temporalClient.TerminateWorkflow(ctx, workflowID, "", "migrating device registry workflow to canonical task queue") +} + +func publishDeviceDefinition( + ctx context.Context, + temporalClient deviceRegistryTemporalClient, + device models.Device, +) error { + if temporalClient == nil { + return fmt.Errorf("shared device registry is unavailable") + } + + device = normalizeDeviceDefinition(device) + if device.ID == "" { + return nil + } + + _, err := temporalClient.SignalWithStartWorkflow( + ctx, + models.TemporalDeviceDefinitionRegistryWorkflowID, + models.TemporalDeviceDefinitionUpsertSignalName, + device, + deviceRegistryStartWorkflowOptions(models.TemporalDeviceDefinitionRegistryWorkflowID), + models.TemporalDeviceDefinitionRegistryWorkflowName, + ) + return err +} + +func (c *Config) querySharedDeviceDefinition(ctx context.Context, deviceID string) (*models.Device, error) { + services := c.GetServices() + if services == nil || !services.HasTemporal() { + return nil, fmt.Errorf("shared device registry is unavailable") + } + + temporalService := services.GetTemporal() + if temporalService == nil || !temporalService.HasClient() { + return nil, fmt.Errorf("shared device registry is unavailable") + } + + return queryDeviceDefinition(ctx, temporalService.GetClient(), deviceID) +} + +func (c *Config) EnsureDeviceRegistryWorkflows(ctx context.Context) error { + if !c.IsServer() { + return nil + } + + services := c.GetServices() + if services == nil || !services.HasTemporal() { + return fmt.Errorf("temporal service is required to manage device registries") + } + + temporalService := services.GetTemporal() + if temporalService == nil || !temporalService.HasClient() { + return fmt.Errorf("temporal client is required to manage device registries") + } + + for _, workflowID := range []string{ + models.TemporalDeviceRouteRegistryWorkflowID, + models.TemporalDeviceDefinitionRegistryWorkflowID, + } { + if err := ensureRegistryWorkflowTaskQueue(ctx, temporalService.GetClient(), workflowID); err != nil { + return err + } + } + + return nil +} + +func (c *Config) PublishConfiguredDeviceDefinitions(ctx context.Context) error { + if !c.IsServer() { + return nil + } + + services := c.GetServices() + if services == nil || !services.HasTemporal() { + return fmt.Errorf("temporal service is required to publish device definitions") + } + + temporalService := services.GetTemporal() + if temporalService == nil || !temporalService.HasClient() { + return fmt.Errorf("temporal client is required to publish device definitions") + } + + deviceIDs := make([]string, 0, len(c.Devices.Definitions)) + for _, device := range c.Devices.Definitions { + deviceID := strings.TrimSpace(device.ID) + if deviceID == "" { + continue + } + deviceIDs = append(deviceIDs, deviceID) + } + slices.Sort(deviceIDs) + + for _, deviceID := range deviceIDs { + device, err := c.GetDevice(deviceID) + if err != nil { + return err + } + if device == nil { + continue + } + if err := publishDeviceDefinition(ctx, temporalService.GetClient(), *device); err != nil { + return err + } + } + + return nil +} diff --git a/internal/config/device_registry_test.go b/internal/config/device_registry_test.go new file mode 100644 index 00000000..d1b7bf13 --- /dev/null +++ b/internal/config/device_registry_test.go @@ -0,0 +1,139 @@ +package config + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thand-io/agent/internal/models" + enumspb "go.temporal.io/api/enums/v1" + taskqueuepb "go.temporal.io/api/taskqueue/v1" + workflowpb "go.temporal.io/api/workflow/v1" + workflowservice "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/sdk/client" + "go.temporal.io/sdk/converter" +) + +type fakeDeviceRegistryClient struct { + describeResponse *workflowservice.DescribeWorkflowExecutionResponse + describeErr error + queryResponse *client.QueryWorkflowWithOptionsResponse + queryErr error + terminated []string + signalOptions []client.StartWorkflowOptions + signalNames []string + signalArgs []any +} + +func (f *fakeDeviceRegistryClient) DescribeWorkflowExecution(ctx context.Context, workflowID, runID string) (*workflowservice.DescribeWorkflowExecutionResponse, error) { + return f.describeResponse, f.describeErr +} + +func (f *fakeDeviceRegistryClient) QueryWorkflowWithOptions(ctx context.Context, request *client.QueryWorkflowWithOptionsRequest) (*client.QueryWorkflowWithOptionsResponse, error) { + return f.queryResponse, f.queryErr +} + +func (f *fakeDeviceRegistryClient) SignalWithStartWorkflow( + ctx context.Context, + workflowID string, + signalName string, + signalArg interface{}, + options client.StartWorkflowOptions, + workflow interface{}, + args ...interface{}, +) (client.WorkflowRun, error) { + f.signalOptions = append(f.signalOptions, options) + f.signalNames = append(f.signalNames, signalName) + f.signalArgs = append(f.signalArgs, signalArg) + return nil, nil +} + +func (f *fakeDeviceRegistryClient) TerminateWorkflow(ctx context.Context, workflowID, runID, reason string, details ...interface{}) error { + f.terminated = append(f.terminated, workflowID) + return nil +} + +func TestEnsureRegistryWorkflowTaskQueueTerminatesWrongQueue(t *testing.T) { + t.Parallel() + + client := &fakeDeviceRegistryClient{ + describeResponse: &workflowservice.DescribeWorkflowExecutionResponse{ + ExecutionConfig: &workflowpb.WorkflowExecutionConfig{ + TaskQueue: &taskqueuepb.TaskQueue{Name: "thand_local_old_server"}, + }, + }, + } + + err := ensureRegistryWorkflowTaskQueue(context.Background(), client, models.TemporalDeviceRouteRegistryWorkflowID) + require.NoError(t, err) + assert.Equal(t, []string{models.TemporalDeviceRouteRegistryWorkflowID}, client.terminated) +} + +func TestEnsureRegistryWorkflowTaskQueueTerminatesVersionedRegistryWorkflow(t *testing.T) { + t.Parallel() + + client := &fakeDeviceRegistryClient{ + describeResponse: &workflowservice.DescribeWorkflowExecutionResponse{ + ExecutionConfig: &workflowpb.WorkflowExecutionConfig{ + TaskQueue: &taskqueuepb.TaskQueue{Name: models.TemporalDeviceRegistryTaskQueue}, + }, + WorkflowExecutionInfo: &workflowpb.WorkflowExecutionInfo{ + VersioningInfo: &workflowpb.WorkflowExecutionVersioningInfo{ + Behavior: enumspb.VERSIONING_BEHAVIOR_AUTO_UPGRADE, + }, + }, + }, + } + + err := ensureRegistryWorkflowTaskQueue(context.Background(), client, models.TemporalDeviceDefinitionRegistryWorkflowID) + require.NoError(t, err) + assert.Equal(t, []string{models.TemporalDeviceDefinitionRegistryWorkflowID}, client.terminated) +} + +func TestPublishDeviceDefinitionUsesCanonicalRegistryQueue(t *testing.T) { + t.Parallel() + + fakeClient := &fakeDeviceRegistryClient{} + err := publishDeviceDefinition(context.Background(), fakeClient, models.Device{ + ID: "device-alpha", + Name: "Device Alpha", + Enabled: true, + }) + require.NoError(t, err) + require.Len(t, fakeClient.signalOptions, 1) + assert.Equal(t, models.TemporalDeviceRegistryTaskQueue, fakeClient.signalOptions[0].TaskQueue) + assert.Equal(t, models.TemporalDeviceDefinitionUpsertSignalName, fakeClient.signalNames[0]) + assert.Nil(t, fakeClient.signalOptions[0].VersioningOverride) +} + +func TestDeviceRegistryStartWorkflowOptionsOmitsVersioningOverride(t *testing.T) { + t.Parallel() + + opts := deviceRegistryStartWorkflowOptions(models.TemporalDeviceDefinitionRegistryWorkflowID) + assert.Equal(t, models.TemporalDeviceRegistryTaskQueue, opts.TaskQueue) + assert.Nil(t, opts.VersioningOverride) +} + +func TestQueryDeviceDefinitionReturnsStoredDevice(t *testing.T) { + t.Parallel() + + payloads, err := converter.GetDefaultDataConverter().ToPayloads(models.Device{ + ID: "device-alpha", + Name: "Device Alpha", + Enabled: true, + }) + require.NoError(t, err) + + client := &fakeDeviceRegistryClient{ + queryResponse: &client.QueryWorkflowWithOptionsResponse{ + QueryResult: client.NewValue(payloads), + }, + } + + device, err := queryDeviceDefinition(context.Background(), client, "device-alpha") + require.NoError(t, err) + require.NotNil(t, device) + assert.Equal(t, "device-alpha", device.ID) + assert.Equal(t, "Device Alpha", device.Name) +} diff --git a/internal/config/device_route_registry.go b/internal/config/device_route_registry.go new file mode 100644 index 00000000..1e9afb52 --- /dev/null +++ b/internal/config/device_route_registry.go @@ -0,0 +1,101 @@ +package config + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/thand-io/agent/internal/models" + "go.temporal.io/sdk/workflow" +) + +const deviceRouteRegistryTickInterval = 30 * time.Second + +func deviceRouteRegistryWorkflow(ctx workflow.Context) error { + routes := map[string]models.DeviceConnectionState{} + + if err := workflow.SetQueryHandler(ctx, models.TemporalGetDeviceRouteQueryName, func(deviceID string) (*models.DeviceConnectionState, error) { + deviceID = strings.TrimSpace(deviceID) + if deviceID == "" { + return nil, fmt.Errorf("device id is required") + } + + route, ok := routes[deviceID] + if !ok { + return nil, fmt.Errorf("%w: device %q is not connected", ErrDeviceRouteUnavailable, deviceID) + } + + route.Connected = route.TaskQueue != "" && !route.LastSeenAt.IsZero() && workflow.Now(ctx).Sub(route.LastSeenAt) <= models.DeviceRouteFreshnessTTL + if !route.Connected { + return nil, fmt.Errorf("%w: device %q is not connected", ErrDeviceRouteUnavailable, deviceID) + } + + routeCopy := route + return &routeCopy, nil + }); err != nil { + return err + } + + signalCh := workflow.GetSignalChannel(ctx, models.TemporalDeviceRouteUpsertSignalName) + for { + cancelled := false + selector := workflow.NewSelector(ctx) + selector.AddReceive(signalCh, func(c workflow.ReceiveChannel, _ bool) { + var route models.DeviceConnectionState + c.Receive(ctx, &route) + route.DeviceID = strings.TrimSpace(route.DeviceID) + if route.DeviceID == "" { + return + } + route.TaskQueue = strings.TrimSpace(route.TaskQueue) + if route.LastSeenAt.IsZero() { + route.LastSeenAt = workflow.Now(ctx) + } + route.Connected = route.TaskQueue != "" && workflow.Now(ctx).Sub(route.LastSeenAt) <= models.DeviceRouteFreshnessTTL + routes[route.DeviceID] = route + }) + selector.AddReceive(ctx.Done(), func(workflow.ReceiveChannel, bool) { + cancelled = true + }) + selector.AddFuture(workflow.NewTimer(ctx, deviceRouteRegistryTickInterval), func(workflow.Future) {}) + selector.Select(ctx) + if cancelled { + return ctx.Err() + } + } +} + +func (c *Config) PublishDeviceConnectionState(ctx context.Context, state models.DeviceConnectionState) error { + c.SetDeviceConnectionState(state) + + services := c.GetServices() + if services == nil || !services.HasTemporal() { + return nil + } + + temporalService := services.GetTemporal() + if temporalService == nil || !temporalService.HasClient() { + return nil + } + + state.DeviceID = strings.TrimSpace(state.DeviceID) + state.TaskQueue = strings.TrimSpace(state.TaskQueue) + if state.DeviceID == "" || state.TaskQueue == "" { + return nil + } + if state.LastSeenAt.IsZero() { + state.LastSeenAt = time.Now().UTC() + } + state.Connected = true + + _, err := temporalService.GetClient().SignalWithStartWorkflow( + ctx, + models.TemporalDeviceRouteRegistryWorkflowID, + models.TemporalDeviceRouteUpsertSignalName, + state, + deviceRegistryStartWorkflowOptions(models.TemporalDeviceRouteRegistryWorkflowID), + models.TemporalDeviceRouteRegistryWorkflowName, + ) + return err +} diff --git a/internal/config/device_route_registry_test.go b/internal/config/device_route_registry_test.go new file mode 100644 index 00000000..b038ab1f --- /dev/null +++ b/internal/config/device_route_registry_test.go @@ -0,0 +1,76 @@ +package config + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thand-io/agent/internal/models" + "go.temporal.io/sdk/testsuite" +) + +func TestDeviceRouteRegistryWorkflowReturnsFreshRouteByDeviceID(t *testing.T) { + var suite testsuite.WorkflowTestSuite + env := suite.NewTestWorkflowEnvironment() + + env.RegisterDelayedCallback(func() { + env.SignalWorkflow(models.TemporalDeviceRouteUpsertSignalName, models.DeviceConnectionState{ + DeviceID: "device-alpha", + TaskQueue: "thand-local-alpha", + Name: "Device Alpha", + Hostname: "host-one", + Platform: "local", + }) + }, time.Second) + + env.RegisterDelayedCallback(func() { + value, err := env.QueryWorkflow(models.TemporalGetDeviceRouteQueryName, "device-alpha") + require.NoError(t, err) + + var route models.DeviceConnectionState + require.NoError(t, value.Get(&route)) + assert.Equal(t, "thand-local-alpha", route.TaskQueue) + assert.Equal(t, "host-one", route.Hostname) + + env.CancelWorkflow() + }, 2*time.Second) + + env.ExecuteWorkflow(deviceRouteRegistryWorkflow) + require.True(t, env.IsWorkflowCompleted()) + require.Error(t, env.GetWorkflowError()) +} + +func TestDeviceRouteRegistryWorkflowUsesHostnameAsMetadataOnly(t *testing.T) { + var suite testsuite.WorkflowTestSuite + env := suite.NewTestWorkflowEnvironment() + + env.RegisterDelayedCallback(func() { + env.SignalWorkflow(models.TemporalDeviceRouteUpsertSignalName, models.DeviceConnectionState{ + DeviceID: "device-alpha", + TaskQueue: "thand-local-alpha", + Hostname: "host-one", + }) + env.SignalWorkflow(models.TemporalDeviceRouteUpsertSignalName, models.DeviceConnectionState{ + DeviceID: "device-alpha", + TaskQueue: "thand-local-alpha", + Hostname: "host-two", + }) + }, time.Second) + + env.RegisterDelayedCallback(func() { + value, err := env.QueryWorkflow(models.TemporalGetDeviceRouteQueryName, "device-alpha") + require.NoError(t, err) + + var route models.DeviceConnectionState + require.NoError(t, value.Get(&route)) + assert.Equal(t, "thand-local-alpha", route.TaskQueue) + assert.Equal(t, "host-two", route.Hostname) + + env.CancelWorkflow() + }, 2*time.Second) + + env.ExecuteWorkflow(deviceRouteRegistryWorkflow) + require.True(t, env.IsWorkflowCompleted()) + require.Error(t, env.GetWorkflowError()) +} diff --git a/internal/config/devices.go b/internal/config/devices.go new file mode 100644 index 00000000..485e3a61 --- /dev/null +++ b/internal/config/devices.go @@ -0,0 +1,111 @@ +package config + +import ( + "errors" + "fmt" + "strings" + "time" + + "github.com/thand-io/agent/internal/models" +) + +var ErrDeviceRouteUnavailable = errors.New("device route unavailable") + +const ( + deviceRouteRefreshInterval = models.DeviceRouteRefreshInterval + deviceRouteFreshnessTTL = models.DeviceRouteFreshnessTTL +) + +func (c *Config) GetDevice(deviceID string) (*models.Device, error) { + deviceID = strings.TrimSpace(deviceID) + if deviceID == "" { + return nil, fmt.Errorf("device id is required") + } + + for _, device := range c.Devices.Definitions { + configuredDeviceID := strings.TrimSpace(device.ID) + if configuredDeviceID == "" { + continue + } + if strings.EqualFold(configuredDeviceID, deviceID) { + deviceCopy := device + return &deviceCopy, nil + } + } + + return nil, fmt.Errorf("device %q is not configured", deviceID) +} + +func (c *Config) SetDeviceConnectionState(state models.DeviceConnectionState) { + if strings.TrimSpace(state.DeviceID) == "" { + return + } + + state.DeviceID = strings.TrimSpace(state.DeviceID) + state.TaskQueue = strings.TrimSpace(state.TaskQueue) + if state.LastSeenAt.IsZero() { + state.LastSeenAt = time.Now().UTC() + } + state.Connected = c.isFreshDeviceConnectionState(&state) + + c.deviceConnectionsMu.Lock() + defer c.deviceConnectionsMu.Unlock() + + if c.deviceConnections == nil { + c.deviceConnections = make(map[string]*models.DeviceConnectionState) + } + + stateCopy := state + c.deviceConnections[state.DeviceID] = &stateCopy +} + +func (c *Config) GetDeviceConnectionState(deviceID string) *models.DeviceConnectionState { + deviceID = strings.TrimSpace(deviceID) + if deviceID == "" { + return nil + } + + c.deviceConnectionsMu.RLock() + defer c.deviceConnectionsMu.RUnlock() + + state, ok := c.deviceConnections[deviceID] + if !ok || state == nil { + return nil + } + + stateCopy := *state + stateCopy.Connected = c.isFreshDeviceConnectionState(&stateCopy) + return &stateCopy +} + +func (c *Config) GetFreshDeviceRoute(deviceID string) (*models.DeviceConnectionState, error) { + device, err := c.GetDevice(deviceID) + if err != nil { + return nil, err + } + if !device.Enabled { + return nil, fmt.Errorf("%w: device %q is disabled", ErrDeviceRouteUnavailable, device.ID) + } + connectionState := c.GetDeviceConnectionState(device.ID) + if connectionState == nil || !connectionState.Connected { + return nil, fmt.Errorf("%w: device %q is not connected", ErrDeviceRouteUnavailable, device.ID) + } + if strings.TrimSpace(connectionState.TaskQueue) == "" { + return nil, fmt.Errorf("%w: device %q has no live task queue", ErrDeviceRouteUnavailable, device.ID) + } + + return connectionState, nil +} + +func (c *Config) isFreshDeviceConnectionState(state *models.DeviceConnectionState) bool { + if state == nil { + return false + } + if strings.TrimSpace(state.TaskQueue) == "" { + return false + } + if state.LastSeenAt.IsZero() { + return false + } + return time.Since(state.LastSeenAt) <= deviceRouteFreshnessTTL +} diff --git a/internal/config/devices_test.go b/internal/config/devices_test.go new file mode 100644 index 00000000..09c5889c --- /dev/null +++ b/internal/config/devices_test.go @@ -0,0 +1,117 @@ +package config + +import ( + "strings" + "testing" + "time" + + "github.com/thand-io/agent/internal/models" +) + +func TestGetFreshDeviceRouteUsesConnectedTaskQueue(t *testing.T) { + cfg := &Config{ + Devices: DeviceDefinitionsConfig{ + Definitions: map[string]models.Device{ + "device-alpha": { + ID: "device-alpha", + Name: "Device Alpha", + Enabled: true, + }, + }, + }, + } + + cfg.SetDeviceConnectionState(models.DeviceConnectionState{ + DeviceID: "device-alpha", + TaskQueue: "connected-queue", + }) + + route, err := cfg.GetFreshDeviceRoute("device-alpha") + if err != nil { + t.Fatalf("GetFreshDeviceRoute returned error: %v", err) + } + + if got, want := route.TaskQueue, "connected-queue"; got != want { + t.Fatalf("task queue = %q, want %q", got, want) + } + if !route.Connected { + t.Fatal("expected fresh route to be marked connected") + } +} + +func TestGetFreshDeviceRouteRejectsStaleConnection(t *testing.T) { + cfg := &Config{ + Devices: DeviceDefinitionsConfig{ + Definitions: map[string]models.Device{ + "device-alpha": { + ID: "device-alpha", + Name: "Device Alpha", + Enabled: true, + }, + }, + }, + } + + cfg.SetDeviceConnectionState(models.DeviceConnectionState{ + DeviceID: "device-alpha", + TaskQueue: "connected-queue", + LastSeenAt: time.Now().UTC().Add(-deviceRouteFreshnessTTL - time.Second), + }) + + _, err := cfg.GetFreshDeviceRoute("device-alpha") + if err == nil { + t.Fatal("expected stale route error") + } + if !strings.Contains(err.Error(), `device "device-alpha" is not connected`) { + t.Fatalf("unexpected error: %v", err) + } + + state := cfg.GetDeviceConnectionState("device-alpha") + if state == nil { + t.Fatal("expected stored connection state") + } + if state.Connected { + t.Fatal("expected stale connection state to be marked disconnected") + } +} + +func TestGetDeviceUsesCanonicalDeviceID(t *testing.T) { + cfg := &Config{ + Devices: DeviceDefinitionsConfig{ + Definitions: map[string]models.Device{ + "workstation-alpha": { + ID: "device-alpha", + Name: "Device Alpha", + Enabled: true, + }, + }, + }, + } + + device, err := cfg.GetDevice("device-alpha") + if err != nil { + t.Fatalf("GetDevice returned error: %v", err) + } + + if got, want := device.ID, "device-alpha"; got != want { + t.Fatalf("device id = %q, want %q", got, want) + } +} + +func TestGetDeviceDoesNotTreatMapKeyAsIdentity(t *testing.T) { + cfg := &Config{ + Devices: DeviceDefinitionsConfig{ + Definitions: map[string]models.Device{ + "workstation-alpha": { + ID: "device-alpha", + Name: "Device Alpha", + Enabled: true, + }, + }, + }, + } + + if _, err := cfg.GetDevice("workstation-alpha"); err == nil { + t.Fatal("expected GetDevice to reject YAML map key as device identity") + } +} diff --git a/internal/config/model.go b/internal/config/model.go index d009e1a3..eb432a39 100644 --- a/internal/config/model.go +++ b/internal/config/model.go @@ -54,6 +54,7 @@ type Config struct { Roles RoleConfig `mapstructure:"roles"` Workflows WorkflowConfig `mapstructure:"workflows"` // These are workflows to run for role associated workflows Providers ProviderDefinitionsConfig `mapstructure:"providers"` // These are integration providers like AWS, GCP, etc. + Devices DeviceDefinitionsConfig `mapstructure:"devices"` // Device definitions and per-device policy managed by the server // This is ONLY if the agent is running in server mode // and you want to use https://www.thand.io hosted services @@ -76,6 +77,11 @@ type Config struct { // Provider instances providerInstances map[string]models.Provider + providerBindings map[string]struct{} + + // Device runtime state. + deviceConnections map[string]*models.DeviceConnectionState + deviceConnectionsMu sync.RWMutex } func (c *Config) GetSecret() string { @@ -123,6 +129,10 @@ func (c *Config) GetProvidersConfig() *ProviderDefinitionsConfig { return &c.Providers } +func (c *Config) GetDevicesConfig() *DeviceDefinitionsConfig { + return &c.Devices +} + func (c *Config) GetThandConfig() *models.ThandConfig { return &c.Thand } @@ -238,6 +248,16 @@ func (p *ProviderDefinitionsConfig) GetDefinitions() map[string]models.ProviderC return p.Definitions } +type DeviceDefinitionsConfig struct { + Path string `mapstructure:"path" json:"path"` + + Definitions map[string]models.Device `mapstructure:",remain" json:"definitions"` +} + +func (d *DeviceDefinitionsConfig) GetDefinitions() map[string]models.Device { + return d.Definitions +} + type ProviderPluginConfig struct { Path string `mapstructure:"path"` URL string `mapstructure:"url"` diff --git a/internal/config/providers.go b/internal/config/providers.go index d07358df..7cd025bd 100644 --- a/internal/config/providers.go +++ b/internal/config/providers.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "strings" + "sync" "github.com/hashicorp/go-version" "github.com/sirupsen/logrus" @@ -31,6 +32,8 @@ import ( _ "github.com/thand-io/agent/internal/providers/thand" ) +var providerBindingsMu sync.Mutex + // LoadProviders loads providers from a file or URL and maps them to their implementations func (c *Config) LoadProviders() (map[string]models.ProviderConfig, error) { @@ -235,120 +238,17 @@ func (c *Config) InitializeProviders() error { models.ProviderCapabilityTenants, ) { - logrus.Infoln("Provider", result.key, "supports RBAC/Identities capabilities") - - // Register provider workflows and activities with Temporal if available - if c.IsServer() { + logrus.Infoln("Provider", result.key, "supports synchronization or provisioning capabilities") - if c.GetServices() != nil && c.GetServices().HasTemporal() { - - logrus.Infoln("Registering Temporal workflows/activities for provider", result.key) - - temporalService := c.GetServices().GetTemporal() - - worker := temporalService.GetWorker() - - if worker == nil { - logrus.Errorln("Temporal client is configured but worker is nil, cannot register workflows/activities for provider", result.key) - continue - } - - syncWorkflowName := models.CreateTemporalProviderWorkflowName( - providerResult.GetIdentifier(), - models.TemporalSynchronizeWorkflowName, - ) - - logrus.WithFields(logrus.Fields{ - "workflow": syncWorkflowName, - }).Infoln("Registering provider synchronize workflow with name", syncWorkflowName) - - // Register the provider Synchronize workflow. This updates roles, permissions, - // resources and identities for RBAC. We register this on the provider itself since it's a core part of the provider's functionality, but we register all other workflows and activities separately to allow providers to opt out of Temporal if they want. - worker.RegisterWorkflowWithOptions( - models.CreateProviderSynchronizeWorkflow(providerResult), - workflow.RegisterOptions{ - Name: syncWorkflowName, - VersioningBehavior: workflow.VersioningBehaviorPinned, - }, - ) - - if providerResult.HasCapability(models.ProviderCapabilityProvisioning) { - - authWorkflowName := models.CreateTemporalProviderWorkflowName( - providerResult.GetIdentifier(), - models.TemporalAuthorizeRoleWorkflowName) - - logrus.WithFields(logrus.Fields{ - "workflow": authWorkflowName, - "provider": providerResult.GetIdentifier(), - }).Infoln("Registering provider authorize role workflow with name", authWorkflowName) - - // Register the provider-specific authorize and revoke role workflows. - // These are closure-based: they capture the live provider instance so the - // child workflow can call provider.AuthorizeRole / RevokeRole with a - // full workflow.Context, allowing providers to dispatch activities, - // use workflow.Go, etc. - worker.RegisterWorkflowWithOptions( - models.CreateProviderAuthorizeRoleWorkflow(c, providerResult), - workflow.RegisterOptions{ - Name: authWorkflowName, - VersioningBehavior: workflow.VersioningBehaviorPinned, - }, - ) - - revokeWorkflowName := models.CreateTemporalProviderWorkflowName( - providerResult.GetIdentifier(), - models.TemporalRevokeRoleWorkflowName) - - logrus.WithFields(logrus.Fields{ - "workflow": revokeWorkflowName, - "provider": providerResult.GetIdentifier(), - }).Infoln("Registering provider revoke role workflow with name", revokeWorkflowName) - - worker.RegisterWorkflowWithOptions( - models.CreateProviderRevokeRoleWorkflow(c, providerResult), - workflow.RegisterOptions{ - Name: revokeWorkflowName, - VersioningBehavior: workflow.VersioningBehaviorPinned, - }, - ) - } - - // Register all custom provider workflows - workflowsRegistry := providerResult.RegisterWorkflows() - if workflowsRegistry != nil { - logrus.Infoln("Registering Temporal workflows for provider", result.key) - worker.RegisterWorkflow(workflowsRegistry) - } - - // Register default provider activities - err := models.RegisterProviderActivities(temporalService, providerResult) - if err != nil { - logrus.WithError(err).Errorln("Failed to register default activities for provider:", result.key) - continue - } - - customActivities := providerResult.RegisterActivities() - if customActivities != nil { - // Now register any custom activities defined by the provider - err = models.RegisterActivities( - temporalService, - providerResult.GetIdentifier(), - customActivities, - ) - if err != nil { - logrus.WithError(err).Errorln("Failed to register custom activities for provider:", result.key) - continue - } - } - } + if err := c.registerProviderTemporalBindings(providerResult); err != nil { + logrus.WithError(err).Errorln("Failed to register Temporal bindings for provider:", result.key) + continue + } + if c.IsServer() { logrus.Infoln("Synchronizing provider", result.key) c.synchronizeProvider(result.provider) - } else { - logrus.Infoln("Skipping Temporal registration for provider", result.key, "in non-server mode") - // Non-server mode: provider won't be synchronized, mark ready immediately providerResult.SetReady() } } else { @@ -370,6 +270,121 @@ func (c *Config) InitializeProviders() error { return nil } +func (c *Config) registerProviderTemporalBindings(providerResult models.Provider) error { + if providerResult == nil { + return fmt.Errorf("provider is nil") + } + if c.GetServices() == nil || !c.GetServices().HasTemporal() { + logrus.WithFields(logrus.Fields{ + "provider": providerResult.GetIdentifier(), + "mode": c.GetMode(), + }).Info("Skipping provider Temporal registration because Temporal is unavailable") + return nil + } + + providerBindingsMu.Lock() + defer providerBindingsMu.Unlock() + + if c.providerBindings == nil { + c.providerBindings = map[string]struct{}{} + } + if _, exists := c.providerBindings[providerResult.GetIdentifier()]; exists { + return nil + } + + // Provider bindings should stay on operational workers and never leak onto + // the shared device-registry queue. + temporalService := c.getOperationalTemporalService() + worker := temporalService.GetWorker() + if worker == nil { + return fmt.Errorf("temporal client is configured but worker is nil") + } + + syncWorkflowName := models.CreateTemporalProviderWorkflowName( + providerResult.GetIdentifier(), + models.TemporalSynchronizeWorkflowName, + ) + + worker.RegisterWorkflowWithOptions( + models.CreateProviderSynchronizeWorkflow(providerResult), + workflow.RegisterOptions{ + Name: syncWorkflowName, + VersioningBehavior: workflow.VersioningBehaviorPinned, + }, + ) + + if providerResult.HasCapability(models.ProviderCapabilityProvisioning) { + authWorkflowName := models.CreateTemporalProviderWorkflowName( + providerResult.GetIdentifier(), + models.TemporalAuthorizeRoleWorkflowName, + ) + worker.RegisterWorkflowWithOptions( + models.CreateProviderAuthorizeRoleWorkflow(c, providerResult), + workflow.RegisterOptions{ + Name: authWorkflowName, + VersioningBehavior: workflow.VersioningBehaviorPinned, + }, + ) + + revokeWorkflowName := models.CreateTemporalProviderWorkflowName( + providerResult.GetIdentifier(), + models.TemporalRevokeRoleWorkflowName, + ) + worker.RegisterWorkflowWithOptions( + models.CreateProviderRevokeRoleWorkflow(c, providerResult), + workflow.RegisterOptions{ + Name: revokeWorkflowName, + VersioningBehavior: workflow.VersioningBehaviorPinned, + }, + ) + } + + if workflowsRegistry := providerResult.RegisterWorkflows(); workflowsRegistry != nil { + worker.RegisterWorkflow(workflowsRegistry) + } + + if err := models.RegisterProviderActivities(temporalService, providerResult); err != nil { + return err + } + + if customActivities := providerResult.RegisterActivities(); customActivities != nil { + if err := models.RegisterActivities(temporalService, providerResult.GetIdentifier(), customActivities); err != nil { + return err + } + } + + c.providerBindings[providerResult.GetIdentifier()] = struct{}{} + return nil +} + +func (c *Config) EnsureProviderTemporalBindings() error { + c.mu.RLock() + providers := make([]models.Provider, 0, len(c.providerInstances)) + for _, provider := range c.providerInstances { + providers = append(providers, provider) + } + c.mu.RUnlock() + + for _, provider := range providers { + if !provider.HasAnyCapability( + models.ProviderCapabilityIdentities, + models.ProviderCapabilityUsers, + models.ProviderCapabilityGroups, + models.ProviderCapabilityResources, + models.ProviderCapabilityRoles, + models.ProviderCapabilityPermissions, + models.ProviderCapabilityTenants, + ) { + continue + } + if err := c.registerProviderTemporalBindings(provider); err != nil { + return err + } + } + + return nil +} + // initializeSingleProvider initializes a single provider func (c *Config) initializeSingleProvider(providerKey string, p *models.ProviderConfig) (models.Provider, error) { @@ -438,6 +453,10 @@ func (c *Config) GetProviders() ProviderDefinitionsConfig { return c.Providers } +func (c *Config) GetProviderDefinitions() map[string]models.ProviderConfig { + return c.Providers.Definitions +} + func (c *Config) GetProvider(providerName string) (string, models.Provider, error) { // Get the first provider by provider name @@ -504,7 +523,7 @@ func (c *Config) GetProvidersByCapabilityWithUser(user *models.User, capability if len(capability) != 0 && !provider.HasAnyCapability(capability...) { logrus.WithFields(logrus.Fields{ "capabilities": provider.GetCapabilities(), - }).Debugln("Skipping provider", name, "due to missing capability:", capability) + }).Traceln("Skipping provider", name, "due to missing capability:", capability) continue } diff --git a/internal/config/services.go b/internal/config/services.go index 60c3c591..ce5d58f0 100644 --- a/internal/config/services.go +++ b/internal/config/services.go @@ -1,6 +1,7 @@ package config import ( + "context" "fmt" "github.com/sirupsen/logrus" @@ -37,8 +38,8 @@ func (c *Config) SetupTemporal() error { logrus.Infoln("Setting up temporal services...") - if !c.IsServer() { - return fmt.Errorf("temporal services can only be set up in server mode") + if !c.IsServer() && !c.IsAgent() { + return fmt.Errorf("temporal services can only be set up in server or agent mode") } // Register workflows @@ -53,6 +54,19 @@ func (c *Config) SetupTemporal() error { return fmt.Errorf("registering temporal activities: %w", err) } + if err := c.EnsureProviderTemporalBindings(); err != nil { + return fmt.Errorf("registering provider temporal bindings: %w", err) + } + + if c.IsServer() { + if err := c.EnsureDeviceRegistryWorkflows(context.Background()); err != nil { + return fmt.Errorf("ensuring device registries: %w", err) + } + if err := c.PublishConfiguredDeviceDefinitions(context.Background()); err != nil { + return fmt.Errorf("publishing device definitions: %w", err) + } + } + return nil } diff --git a/internal/config/services/client.go b/internal/config/services/client.go index c3c36204..e29435ad 100644 --- a/internal/config/services/client.go +++ b/internal/config/services/client.go @@ -410,6 +410,11 @@ func (e *localClient) ReloadTemporal() error { logrus.WithField("identities", identities).Info("Configuring Temporal workers for agent mode") } + if e.config.IsServer() { + identities = append(identities, models.TemporalDeviceRegistryTaskQueue) + logrus.WithField("identities", identities).Info("Configuring Temporal workers for server mode") + } + // Get Temporal config from services servicesConfig := e.config.GetServicesConfig() diff --git a/internal/config/temporal.go b/internal/config/temporal.go index 4078b5cf..59bacfee 100644 --- a/internal/config/temporal.go +++ b/internal/config/temporal.go @@ -6,6 +6,7 @@ import ( "github.com/sirupsen/logrus" "github.com/thand-io/agent/internal/models" "go.temporal.io/sdk/activity" + "go.temporal.io/sdk/workflow" ) // Register temporal workflows and activities @@ -15,12 +16,32 @@ func (c *Config) registerTemporalWorkflows() error { return fmt.Errorf("temporal service is not initialized") } - temporalWorker := c.servicesClient.GetTemporal().GetWorker() + if !c.IsServer() { + return nil + } - if temporalWorker == nil { - return fmt.Errorf("temporal worker is not initialized") + // Registry singletons live on the shared device-registry queue rather than + // the per-server operational queue. + registryWorker := c.getDeviceRegistryWorker() + if registryWorker == nil { + return fmt.Errorf("device registry worker is not initialized") } + registryWorker.RegisterWorkflowWithOptions( + deviceRouteRegistryWorkflow, + workflow.RegisterOptions{ + Name: models.TemporalDeviceRouteRegistryWorkflowName, + VersioningBehavior: workflow.VersioningBehaviorAutoUpgrade, + }, + ) + registryWorker.RegisterWorkflowWithOptions( + deviceDefinitionRegistryWorkflow, + workflow.RegisterOptions{ + Name: models.TemporalDeviceDefinitionRegistryWorkflowName, + VersioningBehavior: workflow.VersioningBehaviorAutoUpgrade, + }, + ) + return nil } @@ -31,7 +52,7 @@ func (c *Config) registerTemporalActivities() error { return fmt.Errorf("temporal service is not initialized") } - temporalWorker := c.servicesClient.GetTemporal().GetWorker() + temporalWorker := c.getOperationalTemporalWorker() if temporalWorker == nil { return fmt.Errorf("temporal worker is not initialized") @@ -64,6 +85,13 @@ func (c *Config) registerTemporalActivities() error { ) } + temporalWorker.RegisterActivityWithOptions( + thandActivities.ResolveFreshDeviceRoute, + activity.RegisterOptions{ + Name: models.TemporalResolveFreshDeviceRouteActivityName, + }, + ) + return nil } diff --git a/internal/config/temporal_activities.go b/internal/config/temporal_activities.go index 43df3169..133f3cd1 100644 --- a/internal/config/temporal_activities.go +++ b/internal/config/temporal_activities.go @@ -2,13 +2,16 @@ package config import ( "context" + "errors" "fmt" "strings" "github.com/serverlessworkflow/sdk-go/v3/model" "github.com/sirupsen/logrus" "github.com/thand-io/agent/internal/models" + "go.temporal.io/api/enums/v1" "go.temporal.io/sdk/activity" + "go.temporal.io/sdk/client" "go.temporal.io/sdk/temporal" ) @@ -85,3 +88,63 @@ func (t *thandActivities) PatchProviderUpstream( return err } + +func (t *thandActivities) ResolveFreshDeviceRoute( + ctx context.Context, + deviceID string, +) (*models.DeviceConnectionState, error) { + route, err := t.queryFreshDeviceRoute(ctx, deviceID) + if err == nil { + return route, nil + } + if errors.Is(err, ErrDeviceRouteUnavailable) { + return nil, temporal.NewNonRetryableApplicationError( + err.Error(), + "DeviceRouteUnavailable", + err, + ) + } + return nil, err +} + +func (t *thandActivities) queryFreshDeviceRoute( + ctx context.Context, + deviceID string, +) (*models.DeviceConnectionState, error) { + services := t.config.GetServices() + if services == nil || !services.HasTemporal() { + return t.config.GetFreshDeviceRoute(deviceID) + } + + temporalService := services.GetTemporal() + if temporalService == nil || !temporalService.HasClient() { + return t.config.GetFreshDeviceRoute(deviceID) + } + + timeoutCtx, cancel := context.WithTimeout(ctx, deviceRouteRefreshInterval) + defer cancel() + + queryResponse, err := temporalService.GetClient().QueryWorkflowWithOptions(timeoutCtx, &client.QueryWorkflowWithOptionsRequest{ + WorkflowID: models.TemporalDeviceRouteRegistryWorkflowID, + RunID: "", + QueryType: models.TemporalGetDeviceRouteQueryName, + QueryRejectCondition: enums.QUERY_REJECT_CONDITION_NOT_OPEN, + Args: []any{deviceID}, + }) + if err != nil { + return nil, fmt.Errorf("%w: device %q is not connected", ErrDeviceRouteUnavailable, strings.TrimSpace(deviceID)) + } + if queryResponse == nil || queryResponse.QueryResult == nil { + return nil, fmt.Errorf("%w: device %q is not connected", ErrDeviceRouteUnavailable, strings.TrimSpace(deviceID)) + } + + var route models.DeviceConnectionState + if err := queryResponse.QueryResult.Get(&route); err != nil { + return nil, err + } + if !route.Connected || strings.TrimSpace(route.TaskQueue) == "" { + return nil, fmt.Errorf("%w: device %q is not connected", ErrDeviceRouteUnavailable, strings.TrimSpace(deviceID)) + } + + return &route, nil +} diff --git a/internal/config/temporal_workers.go b/internal/config/temporal_workers.go new file mode 100644 index 00000000..48daaf8e --- /dev/null +++ b/internal/config/temporal_workers.go @@ -0,0 +1,99 @@ +package config + +import ( + "github.com/thand-io/agent/internal/models" + "go.temporal.io/sdk/client" + "go.temporal.io/sdk/worker" +) + +// temporalWorkerScope narrows a Temporal service view to a specific worker set. +// In server mode we run both the operational worker queue and the shared device +// registry queue, and this wrapper keeps ordinary workflow/activity/provider +// registrations from accidentally landing on the registry worker. +type temporalWorkerScope struct { + base models.TemporalImpl + workerIDs []string +} + +func (t *temporalWorkerScope) Initialize() error { + return t.base.Initialize() +} + +func (t *temporalWorkerScope) Shutdown() error { + return t.base.Shutdown() +} + +func (t *temporalWorkerScope) GetClient() client.Client { + return t.base.GetClient() +} + +func (t *temporalWorkerScope) HasClient() bool { + return t.base.HasClient() +} + +func (t *temporalWorkerScope) GetWorker(identities ...string) worker.Worker { + if len(identities) == 0 { + identities = t.workerIDs + } + return t.base.GetWorker(identities...) +} + +func (t *temporalWorkerScope) HasWorker() bool { + return t.base.HasWorker() +} + +func (t *temporalWorkerScope) GetHostPort() string { + return t.base.GetHostPort() +} + +func (t *temporalWorkerScope) GetNamespace() string { + return t.base.GetNamespace() +} + +func (t *temporalWorkerScope) GetTaskQueue() string { + return t.base.GetTaskQueue() +} + +func (t *temporalWorkerScope) IsVersioningDisabled() bool { + return t.base.IsVersioningDisabled() +} + +// getOperationalTemporalWorker returns the worker that should own normal server +// workflows and activities. Device-registry singletons are registered on a +// separate shared queue. +func (c *Config) getOperationalTemporalWorker() worker.Worker { + temporalService := c.servicesClient.GetTemporal() + if temporalService == nil { + return nil + } + if c.IsServer() { + return temporalService.GetWorker(temporalService.GetTaskQueue()) + } + return temporalService.GetWorker() +} + +// getOperationalTemporalService scopes provider workflow/activity registration +// away from the shared device-registry queue in server mode. +func (c *Config) getOperationalTemporalService() models.TemporalImpl { + temporalService := c.servicesClient.GetTemporal() + if temporalService == nil { + return nil + } + if c.IsServer() { + return &temporalWorkerScope{ + base: temporalService, + workerIDs: []string{temporalService.GetTaskQueue()}, + } + } + return temporalService +} + +// getDeviceRegistryWorker returns the shared worker that owns device registry +// singleton workflows across servers. +func (c *Config) getDeviceRegistryWorker() worker.Worker { + temporalService := c.servicesClient.GetTemporal() + if temporalService == nil { + return nil + } + return temporalService.GetWorker(models.TemporalDeviceRegistryTaskQueue) +} diff --git a/internal/daemon/register.go b/internal/daemon/register.go index 5a5d64a0..aa501bef 100644 --- a/internal/daemon/register.go +++ b/internal/daemon/register.go @@ -37,11 +37,11 @@ func (s *Server) postRegister(c *gin.Context) { cfg := s.GetConfig() c.JSON(http.StatusOK, config.RegistrationResponse{ - Success: true, - Services: &cfg.Services, - //Roles: &cfg.Roles, - //Providers: &cfg.Providers, - //Workflows: &cfg.Workflows, + Success: true, + Services: &cfg.Services, + Roles: &cfg.Roles, + Providers: &cfg.Providers, + Workflows: &cfg.Workflows, }) } diff --git a/internal/daemon/register_test.go b/internal/daemon/register_test.go new file mode 100644 index 00000000..5aec3a4a --- /dev/null +++ b/internal/daemon/register_test.go @@ -0,0 +1,149 @@ +package daemon + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/serverlessworkflow/sdk-go/v3/model" + "github.com/thand-io/agent/internal/config" + "github.com/thand-io/agent/internal/models" +) + +func TestPostRegisterReturnsConfigurationDefinitions(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + + cfg := &config.Config{ + Roles: config.RoleConfig{ + Definitions: map[string]models.Role{ + "viewer": { + Name: "Viewer", + Description: "Read-only access", + Enabled: true, + }, + }, + }, + Workflows: config.WorkflowConfig{ + Definitions: map[string]models.Workflow{ + "approval": { + Name: "Approval", + Description: "Approval workflow", + Enabled: true, + Workflow: &model.Workflow{ + Do: &model.TaskList{}, + }, + }, + }, + }, + Providers: config.ProviderDefinitionsConfig{ + Definitions: map[string]models.ProviderConfig{ + "oauth2-directory": { + Name: "Directory Login", + Description: "Remote OAuth2 provider", + Provider: "oauth2", + Enabled: true, + }, + }, + }, + } + + server := NewServer(cfg) + router := gin.New() + router.POST("/register", server.postRegister) + + body, err := json.Marshal(config.RegistrationRequest{ + Identifier: uuid.New(), + Environment: &models.EnvironmentConfig{ + Name: "device-alpha", + Hostname: "device-alpha.example.test", + Platform: models.Local, + }, + }) + if err != nil { + t.Fatalf("Marshal registration request: %v", err) + } + + req := httptest.NewRequest(http.MethodPost, "/register", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp config.RegistrationResponse + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("Unmarshal response: %v", err) + } + + if resp.Providers == nil || resp.Providers.Definitions["oauth2-directory"].Provider != "oauth2" { + t.Fatalf("response providers missing expected definition: %#v", resp.Providers) + } + if resp.Roles == nil || resp.Roles.Definitions["viewer"].Name != "Viewer" { + t.Fatalf("response roles missing expected definition: %#v", resp.Roles) + } + if resp.Workflows == nil || resp.Workflows.Definitions["approval"].Name != "Approval" { + t.Fatalf("response workflows missing expected definition: %#v", resp.Workflows) + } +} + +func TestPostRegisterOmitsDeviceData(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + + deviceID := uuid.NewString() + cfg := &config.Config{ + Devices: config.DeviceDefinitionsConfig{ + Definitions: map[string]models.Device{ + "workstation-alpha": { + ID: deviceID, + Name: "Workstation Alpha", + Enabled: true, + }, + }, + }, + } + + server := NewServer(cfg) + router := gin.New() + router.POST("/register", server.postRegister) + + body, err := json.Marshal(config.RegistrationRequest{ + Mode: config.ModeAgent, + Identifier: uuid.MustParse(deviceID), + Environment: &models.EnvironmentConfig{ + Name: "alpha", + Hostname: "alpha.example.test", + Platform: models.Local, + }, + }) + if err != nil { + t.Fatalf("Marshal registration request: %v", err) + } + + req := httptest.NewRequest(http.MethodPost, "/register", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("Unmarshal response: %v", err) + } + + if _, found := resp["device"]; found { + t.Fatalf("expected registration response to omit device data, got %#v", resp["device"]) + } +} diff --git a/internal/models/config.go b/internal/models/config.go index f4e651b8..7fb4b27d 100644 --- a/internal/models/config.go +++ b/internal/models/config.go @@ -43,11 +43,15 @@ type ConfigImpl interface { // Tenants GetTenant(name string) (*ProviderTenant, error) + // Devices + GetDevice(deviceID string) (*Device, error) + // Workflows GetWorkflowByName(name string) (*Workflow, error) GetWorkflowFromElevationRequest(elevationRequest *ElevateRequest) (*Workflow, error) // Providers + GetProviderDefinitions() map[string]ProviderConfig GetProviderByName(name string) (Provider, error) GetProvidersByCapability(capability ...ProviderCapability) map[string]Provider GetProvidersByCapabilityWithUser(user *User, capability ...ProviderCapability) map[string]Provider diff --git a/internal/models/device.go b/internal/models/device.go new file mode 100644 index 00000000..cdc92067 --- /dev/null +++ b/internal/models/device.go @@ -0,0 +1,48 @@ +package models + +import ( + "time" +) + +const ( + TemporalResolveFreshDeviceRouteActivityName = "resolve-fresh-device-route" + + TemporalDeviceRegistryTaskQueue = "thand_device_registry" + + TemporalDeviceRouteRegistryWorkflowName = "DeviceRouteRegistryWorkflow" + TemporalDeviceRouteRegistryWorkflowID = "thand-device-route-registry" + TemporalDeviceRouteUpsertSignalName = "upsert-device-route" + TemporalGetDeviceRouteQueryName = "get-device-route" + + TemporalDeviceDefinitionRegistryWorkflowName = "DeviceDefinitionRegistryWorkflow" + TemporalDeviceDefinitionRegistryWorkflowID = "thand-device-definition-registry" + TemporalDeviceDefinitionUpsertSignalName = "upsert-device-definition" + TemporalGetDeviceDefinitionQueryName = "get-device-definition" +) + +const ( + DeviceRouteRefreshInterval = 30 * time.Second + DeviceRouteFreshnessTTL = 2 * time.Minute +) + +type Device struct { + // Device is a first-class execution target managed by the server. + // It is intentionally separate from provider tenants because device routing, + // local policy, and local lease enforcement have different lifecycle needs. + ID string `json:"device_id,omitempty" yaml:"device_id,omitempty" mapstructure:"device_id"` + Name string `json:"name" yaml:"name" mapstructure:"name"` + Description string `json:"description,omitempty" yaml:"description,omitempty" mapstructure:"description"` + Platform string `json:"platform,omitempty" yaml:"platform,omitempty" mapstructure:"platform"` + Enabled bool `json:"enabled" yaml:"enabled" mapstructure:"enabled"` + LocalElevation *DeviceLocalElevationPolicy `json:"local_elevation,omitempty" yaml:"local_elevation,omitempty" mapstructure:"local_elevation"` +} + +type DeviceConnectionState struct { + DeviceID string `json:"device_id,omitempty" yaml:"device_id,omitempty" mapstructure:"device_id"` + TaskQueue string `json:"task_queue,omitempty" yaml:"task_queue,omitempty" mapstructure:"task_queue"` + Name string `json:"name,omitempty" yaml:"name,omitempty" mapstructure:"name"` + Hostname string `json:"hostname,omitempty" yaml:"hostname,omitempty" mapstructure:"hostname"` + Platform string `json:"platform,omitempty" yaml:"platform,omitempty" mapstructure:"platform"` + LastSeenAt time.Time `json:"last_seen_at,omitempty" yaml:"last_seen_at,omitempty" mapstructure:"last_seen_at"` + Connected bool `json:"connected,omitempty" yaml:"connected,omitempty" mapstructure:"connected"` +} diff --git a/internal/models/device_local_elevation.go b/internal/models/device_local_elevation.go new file mode 100644 index 00000000..acebc019 --- /dev/null +++ b/internal/models/device_local_elevation.go @@ -0,0 +1,75 @@ +package models + +import ( + "fmt" + "strings" +) + +type DeviceLocalElevationPolicy struct { + Enabled bool `json:"enabled" yaml:"enabled" mapstructure:"enabled"` + AllowedModes []string `json:"allowed_modes,omitempty" yaml:"allowed_modes,omitempty" mapstructure:"allowed_modes"` + Accounts []DeviceLocalElevationAccount `json:"accounts,omitempty" yaml:"accounts,omitempty" mapstructure:"accounts"` + DeniedUsernames []string `json:"denied_usernames,omitempty" yaml:"denied_usernames,omitempty" mapstructure:"denied_usernames"` + AllowedUIDRanges []string `json:"allowed_uid_ranges,omitempty" yaml:"allowed_uid_ranges,omitempty" mapstructure:"allowed_uid_ranges"` +} + +type DeviceLocalElevationAccount struct { + Identity string `json:"identity,omitempty" yaml:"identity,omitempty" mapstructure:"identity"` + Email string `json:"email,omitempty" yaml:"email,omitempty" mapstructure:"email"` + Username string `json:"username,omitempty" yaml:"username,omitempty" mapstructure:"username"` + LocalUsername string `json:"local_username" yaml:"local_username" mapstructure:"local_username"` +} + +func (p *DeviceLocalElevationPolicy) AllowsMode(mode string) bool { + if p == nil || !p.Enabled { + return false + } + if len(p.AllowedModes) == 0 { + return true + } + for _, allowed := range p.AllowedModes { + if strings.EqualFold(strings.TrimSpace(allowed), strings.TrimSpace(mode)) { + return true + } + } + return false +} + +func (p *DeviceLocalElevationPolicy) ResolveLocalUsername(identityID string, identity *Identity) (string, error) { + if p == nil || !p.Enabled { + return "", fmt.Errorf("local elevation is not enabled for this device") + } + + trimmedIdentityID := strings.TrimSpace(identityID) + var email string + var username string + if identity != nil && identity.User != nil { + email = strings.TrimSpace(identity.User.Email) + username = strings.TrimSpace(identity.User.Username) + } + + for _, account := range p.Accounts { + if account.matches(trimmedIdentityID, email, username) { + localUsername := strings.TrimSpace(account.LocalUsername) + if localUsername == "" { + return "", fmt.Errorf("device account mapping matched without a local username") + } + return localUsername, nil + } + } + + return "", fmt.Errorf("identity %q is not eligible for local sudo on this device", trimmedIdentityID) +} + +func (a DeviceLocalElevationAccount) matches(identityID, email, username string) bool { + if strings.TrimSpace(a.Identity) != "" && strings.EqualFold(strings.TrimSpace(a.Identity), identityID) { + return true + } + if strings.TrimSpace(a.Email) != "" && strings.EqualFold(strings.TrimSpace(a.Email), email) { + return true + } + if strings.TrimSpace(a.Username) != "" && strings.EqualFold(strings.TrimSpace(a.Username), username) { + return true + } + return false +} diff --git a/internal/models/device_registry_queue.go b/internal/models/device_registry_queue.go deleted file mode 100644 index 2bb0326a..00000000 --- a/internal/models/device_registry_queue.go +++ /dev/null @@ -1,3 +0,0 @@ -package models - -const TemporalDeviceRegistryTaskQueue = "thand_device_registry" diff --git a/internal/workflows/manager/workflows.go b/internal/workflows/manager/workflows.go index 0313add4..e02bd020 100644 --- a/internal/workflows/manager/workflows.go +++ b/internal/workflows/manager/workflows.go @@ -32,7 +32,7 @@ func (m *ThandWorkflowManager) registerThandWorkflows() error { return fmt.Errorf("temporal worker not configured") } - worker := temporalService.GetWorker() + worker := temporalService.GetWorker(temporalService.GetTaskQueue()) // Register the primary workflow with Pinned versioning behavior // From 5b09d7957d197708c60ebc523cf9dfb61c587762 Mon Sep 17 00:00:00 2001 From: Michael Weber Date: Wed, 22 Apr 2026 18:17:23 -0500 Subject: [PATCH 07/23] feat(workflows): add execution planning for device-targeted provider tasks --- docs/configuration/workflows/index.md | 26 +- docs/configuration/workflows/tasks.md | 80 ++++- docs/internal/device-model.md | 6 + docs/internal/execution-planning.md | 81 +++++ docs/internal/index.md | 15 + .../config/environment/gcp/workflows.yaml | 2 +- .../environment/kubernetes/workflows.yaml | 2 +- .../config/environment/local/workflows.yaml | 31 ++ internal/config/execution_plan.go | 195 ++++++++++++ internal/config/local_sudo_execution_plan.go | 29 ++ internal/config/providers.go | 4 +- internal/config/temporal.go | 6 + internal/config/temporal_activities.go | 34 ++- internal/models/device.go | 1 + internal/models/elevate.go | 24 +- internal/models/execution_plan.go | 67 +++++ internal/models/provider_rbac.go | 63 ++++ internal/models/provider_workflows.go | 156 ++++------ .../models/provider_workflows_request_test.go | 161 ++++++++++ internal/models/role_clone.go | 15 + internal/models/workflow_elevate_task.go | 22 ++ .../tasks/providers/thand/authorize.go | 202 +++++++------ .../tasks/providers/thand/device_routing.go | 109 +++++++ .../providers/thand/device_routing_test.go | 226 ++++++++++++++ .../workflows/tasks/providers/thand/errors.go | 61 ++++ .../tasks/providers/thand/execution_plan.go | 60 ++++ .../providers/thand/execution_plan_test.go | 284 ++++++++++++++++++ .../workflows/tasks/providers/thand/main.go | 8 + .../workflows/tasks/providers/thand/revoke.go | 283 +++++++++-------- sdk/constants/workflow_elevate_task.go | 13 +- sdk/workflows/models/workflow_task.go | 13 +- 31 files changed, 1926 insertions(+), 353 deletions(-) create mode 100644 docs/internal/execution-planning.md create mode 100644 docs/internal/index.md create mode 100644 internal/config/execution_plan.go create mode 100644 internal/config/local_sudo_execution_plan.go create mode 100644 internal/models/execution_plan.go create mode 100644 internal/models/provider_workflows_request_test.go create mode 100644 internal/models/role_clone.go create mode 100644 internal/workflows/tasks/providers/thand/device_routing.go create mode 100644 internal/workflows/tasks/providers/thand/device_routing_test.go create mode 100644 internal/workflows/tasks/providers/thand/execution_plan.go create mode 100644 internal/workflows/tasks/providers/thand/execution_plan_test.go diff --git a/docs/configuration/workflows/index.md b/docs/configuration/workflows/index.md index a7f4ba85..59947ec6 100644 --- a/docs/configuration/workflows/index.md +++ b/docs/configuration/workflows/index.md @@ -17,9 +17,10 @@ Thand workflows orchestrate the complete lifecycle of access requests: 1. **Validation** - Verify request validity and user permissions 2. **Approval** - Route requests through approval chains with notifications -3. **Authorization** - Grant temporary access to requested resources -4. **Monitoring** - Track usage and detect policy violations -5. **Revocation** - Remove access when complete or violated +3. **Execution Planning** - Compile the final request into an execution plan +4. **Authorization** - Grant temporary access to requested resources +5. **Monitoring** - Track usage and detect policy violations +6. **Revocation** - Remove access when complete or violated Workflows leverage the [Serverless Workflow DSL](https://serverlessworkflow.io/specification/) for standardized process definition while providing custom Thand-specific tasks for access control operations. @@ -57,6 +58,19 @@ workflows: - grant: { thand: authorize } ``` +## Execution Planning + +Access-granting workflows should treat execution planning as part of the standard `authorize` lifecycle. + +`authorize` runs an internal execution-plan activity after validation, approvals, and any other step that can still change the final request. That activity compiles the request into an internal execution plan that `authorize` and `revoke` later consume. + +In practice, the safe patterns are: + +- `validate -> authorize` +- `validate -> approvals -> authorize` + +If a workflow grants access, do not put any request-shaping step after `authorize`, because the execution plan is snapped there. + ## Workflow Structure ### Basic Configuration @@ -132,7 +146,7 @@ workflows: approved: authorize denied: deny-notification then: deny-notification - + # Step 3: Grant access if approved - authorize: thand: authorize @@ -287,6 +301,7 @@ workflows: do: - validate: { thand: validate } - approve: { thand: approvals } + - authorize: { thand: authorize } ``` ## Workflow Patterns @@ -399,12 +414,11 @@ workflows: - when: '${ .duration > "PT4H" }' then: manager-approval - when: '${ .user.department == "security" }' - then: auto-approve + then: authorize default: standard-approval - security-approval: { thand: approvals, then: authorize } - manager-approval: { thand: approvals, then: authorize } - - auto-approve: { thand: authorize, then: end } - standard-approval: { thand: approvals, then: authorize } - authorize: { thand: authorize, then: end } ``` diff --git a/docs/configuration/workflows/tasks.md b/docs/configuration/workflows/tasks.md index ae513c5e..3967e4f8 100644 --- a/docs/configuration/workflows/tasks.md +++ b/docs/configuration/workflows/tasks.md @@ -226,6 +226,59 @@ The approvals task implements the following logic: denied: denied ``` +## Execution Planning + +`authorize` runs an internal execution-plan activity that compiles the current workflow request into the execution plan later used by `authorize` and `revoke`. + +Treat `authorize` as the last request-shaping step before access is granted. It should run only after approvals, form collection, and other workflow logic that might change the final request shape. + +### Execution-Plan Process + +The execution-plan activity: + +1. **Reads** the normalized elevate request from workflow context +2. **Resolves** the provider, identity, device, and local policy data needed for execution +3. **Compiles** one or more provider authorization requests into an internal execution plan +4. **Stores** that execution plan in workflow context/history for later `authorize` and `revoke` steps + +### Requirements and Constraints + +- execution planning is required before provider authorization work starts +- `authorize` should appear immediately after approvals or any other request-shaping step +- `authorize` and `revoke` depend on the recorded execution plan and do not rebuild it later +- `revoke` fails if the execution plan is missing + +### Failure Behavior + +- if execution planning cannot compile the request, the workflow fails before authorization starts +- if `revoke` runs without a recorded execution plan, it fails instead of trying to recover implicitly + +### Examples + +**Validation, Approval, and Authorization** +```yaml +- validate: + thand: validate + then: approvals + +- approvals: + thand: approvals + on: + approved: authorize + denied: denied + then: denied + +- authorize: + thand: authorize +``` + +**Validation and Authorization Without Approval** +```yaml +- validate: + thand: validate + then: authorize +``` + ## authorize The `authorize` task grants temporary access to the requested role and resources. @@ -250,14 +303,17 @@ The `authorize` task grants temporary access to the requested role and resources The authorize task: -1. **Validates** the request is approved (checks workflow context) -2. **Creates** temporary credentials/access across all specified providers -3. **Registers** the session information -4. **Returns** authorization details with timestamps +1. **Builds or reuses** the recorded execution plan from workflow context/history +2. **Validates** the request is approved (checks workflow context) +3. **Creates** temporary credentials/access across all specified providers +4. **Registers** the session information +5. **Returns** authorization details with timestamps ### Authorization Context -The authorize task checks if the request has been approved by looking at the workflow context. If already approved, it returns basic model output with timestamps. +The authorize task checks if the request has been approved by looking at the workflow context. It also snapshots the final request into the execution plan it and `revoke` later consume. + +For any workflow that grants access, treat `authorize` as the first step that may perform provider-side effects. ### Examples @@ -352,8 +408,8 @@ The `revoke` task removes granted access and cleans up temporary credentials. The revoke task: -1. **Validates** the elevate request from workflow context -2. **Iterates** through all providers and identities +1. **Reads** the recorded execution plan from workflow context/history +2. **Uses** the stored authorization request shape to build revocation work 3. **Calls** provider-specific revocation methods 4. **Logs** revocation events 5. **Returns** revocation status with timestamp @@ -576,7 +632,13 @@ Error: authorization failed for user 'alice' role 'admin' ``` **Solution**: Verify the request has been properly approved and the user has permission to request the role. -#### 4. Monitoring Limitations +#### 4. Missing Execution Plan +``` +Error: failed to get execution plan from workflow context +``` +**Solution**: Ensure `thand: revoke` only runs on paths where `thand: authorize` has already executed and recorded the execution plan. + +#### 5. Monitoring Limitations ``` Error: Monitoring is only supported with temporal ``` @@ -591,4 +653,4 @@ logging: level: debug ``` -Check workflow context to understand task inputs and state. \ No newline at end of file +Check workflow context to understand task inputs and state. diff --git a/docs/internal/device-model.md b/docs/internal/device-model.md index ac4e823b..14ae3e18 100644 --- a/docs/internal/device-model.md +++ b/docs/internal/device-model.md @@ -32,6 +32,7 @@ In the current model, a device has: - a stable device ID - human-readable metadata such as `name` and `description` +- optional per-device local-elevation policy Runtime connection state is tracked separately from static device policy. That runtime state currently includes: @@ -62,9 +63,12 @@ The intended architecture is: 2. An agent represents one device, running as a system-level service rather than a per-user helper. 3. `/register` bootstraps config only; running agents publish live route state directly to Temporal. 4. Device-targeted workflows route through that live route only. +5. Device-local capabilities such as local sudo are layered on top of the device substrate. Today the canonical `device_id` is machine-derived. Longer term, device registration should use a stronger enrolled identity, but keep the same `device_id` abstraction boundary. +Device-targeted workflows also rely on a separate execution-planning phase before authorization. That workflow-level contract is documented in [Execution Planning](/internal/execution-planning.html). + ## Phase 1: What Is Implemented Now Phase 1 establishes the basic device substrate without yet solving strong device identity. @@ -127,6 +131,7 @@ Other gaps: - no secure enrollment story yet - shared device registries are still internal Temporal workflows rather than a broader device control-plane service - no independent privileged helper transport yet on Linux or Windows +- macOS now has an initial native privilege-services split, with an app-managed login item, broker daemon, and brokerctl bridge for timed sudoers grants - no explicit multi-agent-per-device design, because the current assumption is one system agent per device ## Future Phases @@ -147,6 +152,7 @@ Future work should cover at least: ### Privileged local helper +- macOS LaunchDaemon broker plus per-user notification helper - OS-native trust checks between the unprivileged agent, broker, and notifier - narrow local lease/enforcer contract with persisted expiry and restart reconciliation - future Linux and Windows helpers that match the same broker client abstraction diff --git a/docs/internal/execution-planning.md b/docs/internal/execution-planning.md new file mode 100644 index 00000000..564b0c30 --- /dev/null +++ b/docs/internal/execution-planning.md @@ -0,0 +1,81 @@ +--- +layout: default +title: Execution Planning +parent: Internal +nav_order: 2 +--- + +# Execution Planning + +This document describes the internal `execution_plan` contract and the rules that `authorize` and `revoke` rely on. + +## Why Execution Planning Exists + +The workflow engine needs a deterministic point where a request stops being user-facing intent and becomes execution-ready work. + +That planning step now happens inside a single Temporal activity invoked by `authorize`. The activity: + +- reads the normalized request after validation and approvals are complete +- resolves the provider, identity, and device data needed for execution +- materializes provider-native authorization requests +- stores the resulting `execution_plan` in workflow context/history + +This keeps mutable lookups out of workflow code while still letting `authorize` and `revoke` stay generic. + +Device-local request shaping is handled by internal execution-plan decorators. That keeps action-specific logic, such as local sudo device-policy enrichment, together without teaching the Temporal activity about individual request types. + +## What the Execution Plan Contains + +The execution plan is an immutable execution snapshot for the rest of the workflow. It contains one or more canonical authorization requests, already shaped for later provider execution. + +Each entry includes: + +- a stable `EntryID` +- the provider name +- the canonical `device_id` used for routing +- a fully materialized provider authorization request + +The plan is an internal contract. It is not a user-facing API and should not be treated as a public workflow output. + +## Execution Contract + +The execution contract is: + +1. workflow input and approvals produce the final request intent +2. `authorize` calls the execution-plan activity once +3. that activity writes `execution_plan` into workflow context/history +4. `authorize` consumes the recorded plan +5. `revoke` later consumes the same recorded plan + +`authorize` is the only task that may create the plan. `revoke` must reuse the recorded plan and fails clearly if it is missing. + +## Routing Rule + +Routing stays intentionally simple: + +- if `DeviceID == ""`, execution stays on the parent workflow queue +- if `DeviceID != ""`, execution is device-scoped and dispatch waits for a fresh route to that device + +Execution planning is responsible for deciding whether a request becomes device-scoped by setting `DeviceID` on the stored authorization request. + +`authorize` does not need to know why the request is device-scoped. It only routes based on `DeviceID`. + +## Ordering Requirements + +Execution planning is required before any access-granting provider work starts, but it is not a user-facing workflow task anymore. + +The intended ordering is: + +- `validate -> authorize` +- `validate -> approvals -> authorize` + +Put `authorize` after approvals, forms, or any other step that can still change the final request shape. That guarantees the execution-plan activity snapshots the final request, not an intermediate one. + +## Failure Semantics and Constraints + +- execution planning is the snapshot point for request shaping +- if execution planning fails, the workflow fails before access is granted +- if policy changes after the plan is recorded, the in-flight workflow keeps using the recorded snapshot +- `revoke` depends on the previously recorded request shape and does not attempt to infer it later + +This separation keeps the execution model predictable and makes failure points easier to reason about. diff --git a/docs/internal/index.md b/docs/internal/index.md new file mode 100644 index 00000000..65c1a9e9 --- /dev/null +++ b/docs/internal/index.md @@ -0,0 +1,15 @@ +--- +layout: default +title: Internal +nav_order: 7 +has_children: true +--- + +# Internal Documentation +{: .no_toc } + +Architecture notes, ADRs, and implementation guidance for maintainers. + +- [Device Model](/internal/device-model.html) +- [Execution Planning](/internal/execution-planning.html) +- [ADR: Device Routing Phase 1](/internal/adr-device-routing-phase-1.html) diff --git a/internal/config/environment/gcp/workflows.yaml b/internal/config/environment/gcp/workflows.yaml index e7ff594f..c2f037bd 100644 --- a/internal/config/environment/gcp/workflows.yaml +++ b/internal/config/environment/gcp/workflows.yaml @@ -31,4 +31,4 @@ workflows: - revoke: thand: revoke then: end - \ No newline at end of file + diff --git a/internal/config/environment/kubernetes/workflows.yaml b/internal/config/environment/kubernetes/workflows.yaml index 470c5e89..5fbe290c 100644 --- a/internal/config/environment/kubernetes/workflows.yaml +++ b/internal/config/environment/kubernetes/workflows.yaml @@ -31,4 +31,4 @@ workflows: - revoke: thand: revoke then: end - \ No newline at end of file + diff --git a/internal/config/environment/local/workflows.yaml b/internal/config/environment/local/workflows.yaml index e00b2eb0..07c4280d 100644 --- a/internal/config/environment/local/workflows.yaml +++ b/internal/config/environment/local/workflows.yaml @@ -32,3 +32,34 @@ workflows: thand: revoke then: end + local_sudo_timed_elevation: + name: "Local Sudo Timed Elevation" + description: Time-bound local sudo access on a server-managed device + authentication: default + enabled: true + workflow: + document: + dsl: "1.0.0-alpha5" + namespace: "thand" + name: "local-sudo-timed-elevation" + version: "1.0.0" + do: + - validate: + thand: validate + with: + validator: static + then: authorize + - authorize: + thand: authorize + with: + revocation: revoke + then: monitor + - monitor: + thand: monitor + with: + monitor: basic + threshold: 100 + then: revoke + - revoke: + thand: revoke + then: end diff --git a/internal/config/execution_plan.go b/internal/config/execution_plan.go new file mode 100644 index 00000000..3e84daa8 --- /dev/null +++ b/internal/config/execution_plan.go @@ -0,0 +1,195 @@ +package config + +import ( + "fmt" + "strings" + + "github.com/thand-io/agent/internal/models" +) + +type executionPlanBuildOptions struct { + LookupDeviceDefinition func(deviceID string) (*models.Device, error) + Decorators []executionPlanDecorator +} + +// executionPlanDecorator lets device-local or provider-specific request shaping +// stay close to the feature that needs it instead of branching inside the +// Temporal activity that drives planning. +type executionPlanDecorator interface { + Applies(elevateRequest *models.ElevateRequestInternal) bool + // Decorate runs before EntryID creation and provider request materialization. + // Use it to populate request metadata and routing fields that should + // contribute to the stable execution-plan entry identity. + Decorate( + cfg models.ConfigImpl, + req *models.WorkflowRoleRequest, + elevateRequest *models.ElevateRequestInternal, + opts executionPlanBuildOptions, + ) error + // Finalize runs after EntryID creation. Use it for metadata that must depend + // on the stable entry identity itself, such as broker grant IDs, without + // feeding that generated value back into the EntryID calculation. + Finalize( + req *models.WorkflowRoleRequest, + elevateRequest *models.ElevateRequestInternal, + entryID string, + ) error +} + +func BuildExecutionPlan( + cfg models.ConfigImpl, + workflowID string, + elevateRequest *models.ElevateRequestInternal, +) (*models.ExecutionPlan, error) { + return BuildExecutionPlanWithOptions(cfg, workflowID, elevateRequest, executionPlanBuildOptions{}) +} + +func BuildExecutionPlanWithOptions( + cfg models.ConfigImpl, + workflowID string, + elevateRequest *models.ElevateRequestInternal, + opts executionPlanBuildOptions, +) (*models.ExecutionPlan, error) { + if elevateRequest == nil { + return nil, fmt.Errorf("elevate request is required for execution planning") + } + if len(elevateRequest.Providers) == 0 { + return nil, fmt.Errorf("no providers specified for authorization") + } + if len(elevateRequest.Identities) == 0 { + return nil, fmt.Errorf("no identities specified for authorization") + } + + opts = opts.withDefaults(cfg) + + duration, err := elevateRequest.AsDuration() + if err != nil { + return nil, fmt.Errorf("failed to get duration: %w", err) + } + + workflowName := strings.TrimSpace(elevateRequest.GetWorkflow()) + if workflowName == "" { + return nil, fmt.Errorf("workflow name is required for execution planning") + } + + tenants := elevateRequest.Tenants + if len(tenants) == 0 { + tenants = []string{""} + } + + plan := &models.ExecutionPlan{WorkflowName: workflowName} + + for _, providerName := range elevateRequest.Providers { + providerName = strings.TrimSpace(providerName) + if providerName == "" { + return nil, fmt.Errorf("execution plan entry is missing provider name") + } + + provider, err := cfg.GetProviderByName(providerName) + if err != nil { + return nil, fmt.Errorf("failed to get provider %q: %w", providerName, err) + } + + for _, identityID := range elevateRequest.Identities { + resolvedIdentity := resolveIdentitySnapshot(cfg, identityID) + + for _, tenantID := range tenants { + workflowReq := &models.WorkflowRoleRequest{ + WorkflowID: workflowID, + Identity: identityID, + ResolvedIdentity: resolvedIdentity, + Role: elevateRequest.Role, + Duration: &duration, + Tenant: tenantID, + } + + if err := applyExecutionPlanDecorators(cfg, workflowReq, elevateRequest, opts); err != nil { + return nil, err + } + + entryID := models.CreateExecutionPlanEntryID(workflowID, providerName, workflowReq) + if err := finalizeExecutionPlanDecorators(workflowReq, elevateRequest, entryID, opts); err != nil { + return nil, err + } + + authorizeRequest, err := models.CreateAuthorizeRoleRequest(cfg, provider, workflowReq) + if err != nil { + return nil, fmt.Errorf("failed to create authorize role request for provider %q and identity %q: %w", providerName, identityID, err) + } + + plan.Entries = append(plan.Entries, models.ExecutionPlanEntry{ + EntryID: entryID, + ProviderName: providerName, + DeviceID: workflowReq.DeviceID, + AuthorizeRequest: authorizeRequest, + }) + } + } + } + + if !plan.IsValid() { + return nil, fmt.Errorf("execution plan did not contain any entries") + } + + return plan, nil +} + +func (opts executionPlanBuildOptions) withDefaults(cfg models.ConfigImpl) executionPlanBuildOptions { + if opts.LookupDeviceDefinition == nil { + opts.LookupDeviceDefinition = cfg.GetDevice + } + if opts.Decorators == nil { + opts.Decorators = []executionPlanDecorator{ + localSudoExecutionPlanDecorator{}, + } + } + return opts +} + +func applyExecutionPlanDecorators( + cfg models.ConfigImpl, + req *models.WorkflowRoleRequest, + elevateRequest *models.ElevateRequestInternal, + opts executionPlanBuildOptions, +) error { + if req == nil { + return fmt.Errorf("workflow role request is required for execution planning") + } + + for _, decorator := range opts.Decorators { + if decorator == nil || !decorator.Applies(elevateRequest) { + continue + } + if err := decorator.Decorate(cfg, req, elevateRequest, opts); err != nil { + return err + } + } + + return nil +} + +func finalizeExecutionPlanDecorators( + req *models.WorkflowRoleRequest, + elevateRequest *models.ElevateRequestInternal, + entryID string, + opts executionPlanBuildOptions, +) error { + for _, decorator := range opts.Decorators { + if decorator == nil || !decorator.Applies(elevateRequest) { + continue + } + if err := decorator.Finalize(req, elevateRequest, entryID); err != nil { + return err + } + } + + return nil +} + +func resolveIdentitySnapshot(cfg models.ConfigImpl, identityID string) *models.Identity { + identityResult, err := cfg.GetIdentity(identityID) + if err != nil || identityResult == nil || identityResult.User == nil { + return nil + } + return identityResult +} diff --git a/internal/config/local_sudo_execution_plan.go b/internal/config/local_sudo_execution_plan.go new file mode 100644 index 00000000..8bcfb41c --- /dev/null +++ b/internal/config/local_sudo_execution_plan.go @@ -0,0 +1,29 @@ +package config + +import "github.com/thand-io/agent/internal/models" + +// localSudoExecutionPlanDecorator is a no-op until local sudo request shaping +// lands. Keeping the hook in the execution-plan layer lets later commits add +// the feature without reshaping this baseline. +type localSudoExecutionPlanDecorator struct{} + +func (localSudoExecutionPlanDecorator) Applies(*models.ElevateRequestInternal) bool { + return false +} + +func (localSudoExecutionPlanDecorator) Decorate( + models.ConfigImpl, + *models.WorkflowRoleRequest, + *models.ElevateRequestInternal, + executionPlanBuildOptions, +) error { + return nil +} + +func (localSudoExecutionPlanDecorator) Finalize( + *models.WorkflowRoleRequest, + *models.ElevateRequestInternal, + string, +) error { + return nil +} diff --git a/internal/config/providers.go b/internal/config/providers.go index 7cd025bd..19070088 100644 --- a/internal/config/providers.go +++ b/internal/config/providers.go @@ -319,7 +319,7 @@ func (c *Config) registerProviderTemporalBindings(providerResult models.Provider models.TemporalAuthorizeRoleWorkflowName, ) worker.RegisterWorkflowWithOptions( - models.CreateProviderAuthorizeRoleWorkflow(c, providerResult), + models.CreateProviderAuthorizeRoleWorkflow(providerResult), workflow.RegisterOptions{ Name: authWorkflowName, VersioningBehavior: workflow.VersioningBehaviorPinned, @@ -331,7 +331,7 @@ func (c *Config) registerProviderTemporalBindings(providerResult models.Provider models.TemporalRevokeRoleWorkflowName, ) worker.RegisterWorkflowWithOptions( - models.CreateProviderRevokeRoleWorkflow(c, providerResult), + models.CreateProviderRevokeRoleWorkflow(providerResult), workflow.RegisterOptions{ Name: revokeWorkflowName, VersioningBehavior: workflow.VersioningBehaviorPinned, diff --git a/internal/config/temporal.go b/internal/config/temporal.go index 59bacfee..9b302720 100644 --- a/internal/config/temporal.go +++ b/internal/config/temporal.go @@ -91,6 +91,12 @@ func (c *Config) registerTemporalActivities() error { Name: models.TemporalResolveFreshDeviceRouteActivityName, }, ) + temporalWorker.RegisterActivityWithOptions( + thandActivities.BuildExecutionPlan, + activity.RegisterOptions{ + Name: models.TemporalBuildExecutionPlanActivityName, + }, + ) return nil diff --git a/internal/config/temporal_activities.go b/internal/config/temporal_activities.go index 133f3cd1..c755ab32 100644 --- a/internal/config/temporal_activities.go +++ b/internal/config/temporal_activities.go @@ -16,7 +16,8 @@ import ( ) type thandActivities struct { - config *Config + config *Config + lookupDeviceDefinition func(ctx context.Context, deviceID string) (*models.Device, error) } // PatchProviderUpstreamDummy is a no-op activity for thand server/agents that are not @@ -107,6 +108,37 @@ func (t *thandActivities) ResolveFreshDeviceRoute( return nil, err } +func (t *thandActivities) BuildExecutionPlan( + ctx context.Context, + req models.ExecutionPlanRequest, +) (*models.ExecutionPlan, error) { + if req.ElevateRequest == nil { + return nil, temporal.NewNonRetryableApplicationError( + "elevate request is required for execution planning", + "ExecutionPlanInvalid", + nil, + ) + } + + plan, err := BuildExecutionPlanWithOptions(t.config, req.WorkflowID, req.ElevateRequest, executionPlanBuildOptions{ + LookupDeviceDefinition: func(deviceID string) (*models.Device, error) { + if t.lookupDeviceDefinition != nil { + return t.lookupDeviceDefinition(ctx, deviceID) + } + return t.config.querySharedDeviceDefinition(ctx, deviceID) + }, + }) + if err == nil { + return plan, nil + } + + return nil, temporal.NewNonRetryableApplicationError( + err.Error(), + "ExecutionPlanInvalid", + err, + ) +} + func (t *thandActivities) queryFreshDeviceRoute( ctx context.Context, deviceID string, diff --git a/internal/models/device.go b/internal/models/device.go index cdc92067..6990476a 100644 --- a/internal/models/device.go +++ b/internal/models/device.go @@ -6,6 +6,7 @@ import ( const ( TemporalResolveFreshDeviceRouteActivityName = "resolve-fresh-device-route" + TemporalBuildExecutionPlanActivityName = "build-execution-plan" TemporalDeviceRegistryTaskQueue = "thand_device_registry" diff --git a/internal/models/elevate.go b/internal/models/elevate.go index b1d5837b..f54b6162 100644 --- a/internal/models/elevate.go +++ b/internal/models/elevate.go @@ -17,6 +17,7 @@ type ElevateStaticRequest struct { Role string `json:"role" form:"role"` Provider string `json:"provider" form:"provider"` Workflow string `json:"workflow" form:"workflow"` + Device string `json:"device,omitempty" form:"device,omitempty"` Reason string `json:"reason" form:"reason" binding:"required"` Duration string `json:"duration,omitempty" form:"duration,omitempty"` // Duration in ISO 8601 format Identities []string `json:"identities,omitempty" form:"identities,omitempty"` // Optional identities to elevate, if empty the requesting user is used @@ -31,6 +32,7 @@ func (r *ElevateStaticRequest) GetUrlParams() url.Values { "reason": {r.Reason}, "role": {r.Role}, "workflow": {r.Workflow}, + "device": {r.Device}, "duration": {r.Duration}, "provider": {r.Provider}, "identities": {strings.Join(r.Identities, ",")}, @@ -56,15 +58,17 @@ type ElevateResponse struct { } type ElevateRequest struct { - Role *Role `json:"role"` - Providers []string `json:"providers"` // A role can be applied to multiple providers - Authenticator string `json:"authenticator"` // Which provider to use for authentication - Workflow string `json:"workflow"` - Reason string `json:"reason"` - Duration string `json:"duration,omitempty"` // Duration in ISO 8601 format - Identities []string `json:"identities,omitempty"` // Optional identities to elevate, if empty the requesting user is used - Tenants []string `json:"tenants,omitempty"` // Optional tenant IDs for multi-account providers - Session *LocalSession `json:"session,omitempty"` + Role *Role `json:"role"` + Providers []string `json:"providers"` // A role can be applied to multiple providers + Authenticator string `json:"authenticator"` // Which provider to use for authentication + Workflow string `json:"workflow"` + Device string `json:"device,omitempty"` // Canonical device_id for local execution + Reason string `json:"reason"` + Duration string `json:"duration,omitempty"` // Duration in ISO 8601 format + Identities []string `json:"identities,omitempty"` // Optional identities to elevate, if empty the requesting user is used + Tenants []string `json:"tenants,omitempty"` // Optional tenant IDs for multi-account providers + Metadata map[string]any `json:"metadata,omitempty"` // Provider/workflow-specific metadata + Session *LocalSession `json:"session,omitempty"` } func (e *ElevateRequest) IsValid() bool { @@ -81,10 +85,12 @@ func (e *ElevateRequest) AsMap() map[string]any { "providers": e.Providers, "authenticator": e.Authenticator, "workflow": e.Workflow, + "device": e.Device, "reason": e.Reason, "duration": e.Duration, "identities": e.Identities, "tenants": e.Tenants, + "metadata": e.Metadata, } } diff --git a/internal/models/execution_plan.go b/internal/models/execution_plan.go new file mode 100644 index 00000000..04014683 --- /dev/null +++ b/internal/models/execution_plan.go @@ -0,0 +1,67 @@ +package models + +type ExecutionPlan struct { + WorkflowName string `json:"workflow_name,omitempty"` + Entries []ExecutionPlanEntry `json:"entries,omitempty"` +} + +func (p *ExecutionPlan) IsValid() bool { + return p != nil && len(p.Entries) > 0 +} + +type ExecutionPlanEntry struct { + EntryID string `json:"entry_id,omitempty"` + ProviderName string `json:"provider_name,omitempty"` + DeviceID string `json:"device_id,omitempty"` + AuthorizeRequest *AuthorizeRoleRequest `json:"authorize_request,omitempty"` +} + +type ExecutionPlanRequest struct { + WorkflowID string `json:"workflow_id,omitempty"` + ElevateRequest *ElevateRequestInternal `json:"elevate_request,omitempty"` +} + +func CloneWorkflowRoleRequest(req *WorkflowRoleRequest) *WorkflowRoleRequest { + if req == nil { + return nil + } + + clone := *req + if req.Role != nil { + clone.Role = CloneRole(req.Role) + } + if req.Duration != nil { + duration := *req.Duration + clone.Duration = &duration + } + if req.Metadata != nil { + clone.Metadata = make(map[string]any, len(req.Metadata)) + for key, value := range req.Metadata { + clone.Metadata[key] = value + } + } + + return &clone +} + +func CloneExecutionPlan(plan *ExecutionPlan) *ExecutionPlan { + if plan == nil { + return nil + } + + clone := &ExecutionPlan{ + WorkflowName: plan.WorkflowName, + Entries: make([]ExecutionPlanEntry, 0, len(plan.Entries)), + } + + for _, entry := range plan.Entries { + clone.Entries = append(clone.Entries, ExecutionPlanEntry{ + EntryID: entry.EntryID, + ProviderName: entry.ProviderName, + DeviceID: entry.DeviceID, + AuthorizeRequest: CloneAuthorizeRoleRequest(entry.AuthorizeRequest), + }) + } + + return clone +} diff --git a/internal/models/provider_rbac.go b/internal/models/provider_rbac.go index 4aea56b0..e71c3a7e 100644 --- a/internal/models/provider_rbac.go +++ b/internal/models/provider_rbac.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "maps" "path" "slices" "sort" @@ -19,6 +20,7 @@ type AuthorizeRoleRequest struct { Identity *Identity `json:"identity,omitempty"` // User or group identifier Role *CompositeRole `json:"role,omitempty"` Duration *time.Duration `json:"duration,omitempty"` // Optional duration for temporary access + Metadata map[string]any `json:"metadata,omitempty"` // Provider-specific workflow metadata } func (r *AuthorizeRoleRequest) IsValid() bool { @@ -48,6 +50,67 @@ func (r *AuthorizeRoleRequest) HasTenant() bool { return r.Tenant != nil && len(r.Tenant.ID) > 0 } +func CloneAuthorizeRoleRequest(req *AuthorizeRoleRequest) *AuthorizeRoleRequest { + if req == nil { + return nil + } + + clone := *req + if req.Tenant != nil { + tenant := *req.Tenant + clone.Tenant = &tenant + } + if req.Identity != nil { + identity := *req.Identity + clone.Identity = &identity + } + if req.Role != nil { + role := *req.Role + clone.Role = &role + } + if req.Duration != nil { + duration := *req.Duration + clone.Duration = &duration + } + if req.Metadata != nil { + clone.Metadata = maps.Clone(req.Metadata) + } + + return &clone +} + +func CloneRevokeRoleRequest(req *RevokeRoleRequest) *RevokeRoleRequest { + if req == nil { + return nil + } + + clone := &RevokeRoleRequest{ + AuthorizeRoleRequest: CloneAuthorizeRoleRequest(req.AuthorizeRoleRequest), + AuthorizeRoleResponse: req.AuthorizeRoleResponse, + } + if req.AuthorizeRoleResponse != nil { + response := *req.AuthorizeRoleResponse + if req.AuthorizeRoleResponse.Roles != nil { + response.Roles = append([]string(nil), req.AuthorizeRoleResponse.Roles...) + } + if req.AuthorizeRoleResponse.Permissions != nil { + response.Permissions = append([]string(nil), req.AuthorizeRoleResponse.Permissions...) + } + if req.AuthorizeRoleResponse.Groups != nil { + response.Groups = append([]string(nil), req.AuthorizeRoleResponse.Groups...) + } + if req.AuthorizeRoleResponse.Resources != nil { + response.Resources = append([]string(nil), req.AuthorizeRoleResponse.Resources...) + } + if req.AuthorizeRoleResponse.Metadata != nil { + response.Metadata = maps.Clone(req.AuthorizeRoleResponse.Metadata) + } + clone.AuthorizeRoleResponse = &response + } + + return clone +} + type AuthorizeRoleResponse struct { UserId string `json:"user_id,omitempty"` // The ID of the user the role was authorized for Roles []string `json:"roles,omitempty"` // The roles that were authorized diff --git a/internal/models/provider_workflows.go b/internal/models/provider_workflows.go index a8caa4fa..e59f6b4e 100644 --- a/internal/models/provider_workflows.go +++ b/internal/models/provider_workflows.go @@ -46,11 +46,10 @@ func CreateTemporalWorkflowIdentifier(workflowName string) string { return strings.ToLower(fmt.Sprintf("%s-%s", common.GetClientIdentifier(), workflowName)) } -// CreateChildWorkflowID generates a unique child workflow ID by hashing a composite -// identifier built from provider, role, identity, tenant, and parent workflow ID. -// This ensures uniqueness across different identities/tenants requesting the same role. -// Format: parentWorkflowID_operation_hash -func CreateChildWorkflowID(parentWorkflowID, operation, provider string, req *WorkflowRoleRequest) string { +// CreateExecutionPlanEntryID generates a stable identifier for a single provider +// execution plan entry. It is derived from the same composite request shape used +// for child workflow IDs so retries map back to the same logical authorization. +func CreateExecutionPlanEntryID(parentWorkflowID, provider string, req *WorkflowRoleRequest) string { // Build composite identifier similar to CompositeRoleWorkflowIdentifier // but using the data available in WorkflowRoleRequest parts := []string{ @@ -66,6 +65,10 @@ func CreateChildWorkflowID(parentWorkflowID, operation, provider string, req *Wo parts = append(parts, req.Identity) + if len(req.DeviceID) > 0 { + parts = append(parts, req.DeviceID) + } + if len(req.Tenant) > 0 { parts = append(parts, req.Tenant) } @@ -77,7 +80,19 @@ func CreateChildWorkflowID(parentWorkflowID, operation, provider string, req *Wo hash := sha256.Sum256([]byte(composite)) hashStr := hex.EncodeToString(hash[:])[:12] // Use first 12 chars (48 bits) - return fmt.Sprintf("%s_%s_%s", parentWorkflowID, operation, hashStr) + return hashStr +} + +// CreateChildWorkflowID generates a unique child workflow ID by hashing a composite +// identifier built from provider, role, identity, tenant, and parent workflow ID. +// This ensures uniqueness across different identities/tenants requesting the same role. +// Format: parentWorkflowID_operation_hash +func CreateChildWorkflowID(parentWorkflowID, operation, provider string, req *WorkflowRoleRequest) string { + return CreateChildWorkflowIDForEntry(parentWorkflowID, operation, CreateExecutionPlanEntryID(parentWorkflowID, provider, req)) +} + +func CreateChildWorkflowIDForEntry(parentWorkflowID, operation, entryID string) string { + return fmt.Sprintf("%s_%s_%s", parentWorkflowID, operation, entryID) } // runSyncLoop runs a single synchronization capability inside a Temporal workflow @@ -285,11 +300,14 @@ func CreateProviderSynchronizeWorkflow(provider Provider) func(workflow.Context, } type WorkflowRoleRequest struct { - WorkflowID string `json:"workflow_id"` // ID of the workflow for which the role is being authorized - Tenant string `json:"tenant,omitempty"` // Optional tenant ID for multi-account providers - Identity string `json:"identity"` // User or group identifier - Role *Role `json:"role"` - Duration *time.Duration `json:"duration,omitempty"` // Optional duration for temporary access + WorkflowID string `json:"workflow_id"` // ID of the workflow for which the role is being authorized + DeviceID string `json:"device_id,omitempty"` + Tenant string `json:"tenant,omitempty"` // Optional tenant ID for multi-account providers + Identity string `json:"identity"` // User or group identifier + ResolvedIdentity *Identity `json:"resolved_identity,omitempty"` + Role *Role `json:"role"` + Duration *time.Duration `json:"duration,omitempty"` // Optional duration for temporary access + Metadata map[string]any `json:"metadata,omitempty"` } // IsValid checks if any of the fields are nil @@ -314,114 +332,44 @@ func (r *WorkflowRoleRequest) GetDuration() *time.Duration { return r.Duration } -// authorizeRoleRequestSideEffect is used to carry the result of -// CreateAuthorizeRoleRequest across a workflow.SideEffect boundary so that -// non-deterministic operations (config lookups, UUID generation) are isolated -// from workflow replay. -type authorizeRoleRequestSideEffect struct { - Request *AuthorizeRoleRequest `json:"request"` - Err string `json:"error"` -} - // CreateProviderAuthorizeRoleWorkflow returns a workflow function that captures the -// live provider instance via closure. The child workflow receives the Temporal -// workflow.Context, constructs a WorkflowTaskSupport with it, and delegates to -// provider.AuthorizeRole — allowing the provider to dispatch activities, use -// workflow.Go, and manage state just as it does in the primary workflow. -// Careful: The workflow function returned by this method will be executed as a Temporal workflow, so it must be deterministic and should not perform any non-deterministic operations (like generating random numbers or accessing the current time) directly in the workflow code. Any such operations should be performed within activities or isolated using workflow.SideEffect to ensure correct behavior during workflow replay. -func CreateProviderAuthorizeRoleWorkflow(cfg ConfigImpl, provider Provider) func(workflow.Context, WorkflowRoleRequest) (*AuthorizeRoleResponse, error) { - return func(ctx workflow.Context, req WorkflowRoleRequest) (*AuthorizeRoleResponse, error) { +// live provider instance via closure. The child workflow receives a fully +// materialized provider request so it can delegate directly to provider +// activities without any workflow-side config lookups. +func CreateProviderAuthorizeRoleWorkflow(provider Provider) func(workflow.Context, AuthorizeRoleRequest) (*AuthorizeRoleResponse, error) { + return func(ctx workflow.Context, req AuthorizeRoleRequest) (*AuthorizeRoleResponse, error) { log := workflow.GetLogger(ctx) log.Info("Starting authorize role workflow", "provider", provider.GetIdentifier()) - // Wrap in a SideEffect so that the non-deterministic operations inside - // CreateAuthorizeRoleRequest (config/identity/tenant lookups, UUID generation - // for the composite role identifier) are executed only on the first run and - // their result is recorded in the workflow event history. On replay, Temporal - // replays the recorded value instead of re-executing the function, keeping - // workflow execution deterministic. - // - // Note: CompositeRole has a custom UnmarshalJSON to ensure UUID, Composite, - // and Providers fields survive JSON serialization through Temporal's data converter. - encodedReq := workflow.SideEffect(ctx, func(ctx workflow.Context) any { - result, err := CreateAuthorizeRoleRequest(cfg, provider, &req) - if err != nil { - return authorizeRoleRequestSideEffect{Err: err.Error()} - } - return authorizeRoleRequestSideEffect{Request: result} - }) - - var se authorizeRoleRequestSideEffect - if err := encodedReq.Get(&se); err != nil { - log.Error("Failed to decode authorize role request side effect", "error", err) - return nil, err - } - if se.Err != "" { - log.Error("Failed to create authorize role request", "error", se.Err) - return nil, fmt.Errorf("%s", se.Err) - } - log.Debug("Constructed authorize role request, invoking provider", "provider", provider.GetIdentifier(), - "authorizeReq", se.Request, + "authorizeReq", req, ) - return provider.AuthorizeRole(ctx, se.Request) + return provider.AuthorizeRole(ctx, &req) } } type WorkflowRevokeRoleRequest struct { - RevokeRoleRequest *WorkflowRoleRequest + RevokeRoleRequest *RevokeRoleRequest AuthorizeRoleResponse *AuthorizeRoleResponse } // CreateProviderRevokeRoleWorkflow returns a workflow function that captures the // live provider instance via closure for revocation operations. -// Careful: The revoke workflow may need to reconstruct the original AuthorizeRoleRequest -// and must handle it deterministically within a workflow.SideEffect. -func CreateProviderRevokeRoleWorkflow(cfg ConfigImpl, provider Provider) func(workflow.Context, WorkflowRevokeRoleRequest) (*RevokeRoleResponse, error) { +func CreateProviderRevokeRoleWorkflow(provider Provider) func(workflow.Context, WorkflowRevokeRoleRequest) (*RevokeRoleResponse, error) { return func(ctx workflow.Context, req WorkflowRevokeRoleRequest) (*RevokeRoleResponse, error) { log := workflow.GetLogger(ctx) log.Info("Starting revoke role workflow", "provider", provider.GetIdentifier()) - var authReq *AuthorizeRoleRequest - if req.RevokeRoleRequest != nil { - // Same reasoning as in CreateProviderAuthorizeRoleWorkflow: wrap in a - // SideEffect to isolate non-deterministic config lookups and UUID - // generation from replay. - encodedReq := workflow.SideEffect(ctx, func(ctx workflow.Context) any { - result, err := CreateAuthorizeRoleRequest(cfg, provider, req.RevokeRoleRequest) - if err != nil { - return authorizeRoleRequestSideEffect{Err: err.Error()} - } - return authorizeRoleRequestSideEffect{Request: result} - }) - - var se authorizeRoleRequestSideEffect - if err := encodedReq.Get(&se); err != nil { - log.Error("Failed to decode revoke role request side effect", "error", err) - return nil, err - } - if se.Err != "" { - log.Error("Failed to create revoke role request", "error", se.Err) - return nil, fmt.Errorf("%s", se.Err) - } - authReq = se.Request - } - - revokeReq := &RevokeRoleRequest{ - AuthorizeRoleRequest: authReq, - AuthorizeRoleResponse: req.AuthorizeRoleResponse, - } - log.Debug("Constructed revoke role request, invoking provider", "provider", provider.GetIdentifier(), - "revokeReq", revokeReq, + "revokeReq", req.RevokeRoleRequest, ) - return provider.RevokeRole(ctx, revokeReq) + return provider.RevokeRole(ctx, req.RevokeRoleRequest) } } @@ -434,13 +382,17 @@ func CreateAuthorizeRoleRequest( ) (*AuthorizeRoleRequest, error) { // Get the user identity from the request - identity, err := cfg.GetIdentity(req.Identity) - if err != nil { - identity = &Identity{ - ID: req.Identity, - User: &User{ - Email: req.Identity, - }, + identity := req.ResolvedIdentity + var err error + if identity == nil { + identity, err = cfg.GetIdentity(req.Identity) + if err != nil { + identity = &Identity{ + ID: req.Identity, + User: &User{ + Email: req.Identity, + }, + } } } @@ -481,11 +433,13 @@ func CreateAuthorizeRoleRequest( Tenant: tenant, Role: compositeRole, Duration: req.Duration, + Metadata: maps.Clone(req.Metadata), }, nil } -// validateRoleAndBuildOutput validates the role and builds the initial model output -// Careful: This function is called within a workflow.SideEffect, so it must be deterministic and cannot perform any Temporal operations (activities, child workflows, timers) or access non-deterministic data (current time, random numbers). It can only use the data passed in the parameters and perform pure computations. +// validateRoleAndBuildOutput validates the role and builds the initial model +// output. It is used during request materialization, so it must stay as a pure +// computation over the provided inputs. func validateRoleAndBuildOutput( provider Provider, elevateRequest ElevateRequestInternal, diff --git a/internal/models/provider_workflows_request_test.go b/internal/models/provider_workflows_request_test.go new file mode 100644 index 00000000..206aa60c --- /dev/null +++ b/internal/models/provider_workflows_request_test.go @@ -0,0 +1,161 @@ +package models_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thand-io/agent/internal/config" + "github.com/thand-io/agent/internal/models" +) + +type authorizeRoleRequestConfigStub struct { + *config.Config + identity *models.Identity + identityErr error + lastCompositeIdentity *models.Identity +} + +func (s *authorizeRoleRequestConfigStub) GetIdentity(byEmail string) (*models.Identity, error) { + if s.identityErr != nil { + return nil, s.identityErr + } + return s.identity, nil +} + +func (s *authorizeRoleRequestConfigStub) GetCompositeRoleForWorkflow( + identity *models.Identity, + baseRole *models.Role, + workflowID string, + providers ...models.Provider, +) (*models.CompositeRole, error) { + s.lastCompositeIdentity = identity + return &models.CompositeRole{ + Role: *baseRole, + Composite: false, + }, nil +} + +type authorizeRoleRequestProvider struct { + *models.BaseProvider +} + +func (p *authorizeRoleRequestProvider) ValidateRole( + ctx context.Context, + user *models.Identity, + role *models.Role, +) (map[string]any, error) { + return map[string]any{}, nil +} + +func newAuthorizeRoleRequestProvider() models.Provider { + return &authorizeRoleRequestProvider{ + BaseProvider: models.NewBaseProvider( + "test-provider", + models.ProviderConfig{ + Name: "Test Provider", + Provider: "test", + }, + models.NewProviderCapabilities(), + ), + } +} + +func newAuthorizeRoleRequestRole() *models.Role { + return &models.Role{ + Name: "local_sudo", + } +} + +func TestCreateAuthorizeRoleRequestUsesResolvedIdentitySnapshot(t *testing.T) { + cfg := &authorizeRoleRequestConfigStub{ + Config: config.DefaultConfig(), + identityErr: errors.New("identity lookup should not be used"), + } + provider := newAuthorizeRoleRequestProvider() + role := newAuthorizeRoleRequestRole() + resolvedIdentity := &models.Identity{ + ID: "identity-abc", + User: &models.User{ + Email: "user@example.com", + Username: "example-user", + Name: "Example User", + }, + } + + req, err := models.CreateAuthorizeRoleRequest(cfg, provider, &models.WorkflowRoleRequest{ + WorkflowID: "wf-123", + Identity: "opaque-identity-id", + ResolvedIdentity: resolvedIdentity, + Role: role, + }) + + require.NoError(t, err) + require.NotNil(t, req) + assert.Same(t, resolvedIdentity, req.Identity) + assert.Same(t, resolvedIdentity, cfg.lastCompositeIdentity) +} + +func TestCreateAuthorizeRoleRequestFallsBackToConfigLookup(t *testing.T) { + lookedUpIdentity := &models.Identity{ + ID: "identity-def", + User: &models.User{ + Email: "resolved@example.com", + Username: "resolved-user", + }, + } + cfg := &authorizeRoleRequestConfigStub{ + Config: config.DefaultConfig(), + identity: lookedUpIdentity, + } + provider := newAuthorizeRoleRequestProvider() + role := newAuthorizeRoleRequestRole() + + req, err := models.CreateAuthorizeRoleRequest(cfg, provider, &models.WorkflowRoleRequest{ + WorkflowID: "wf-456", + Identity: "resolved@example.com", + Role: role, + }) + + require.NoError(t, err) + require.NotNil(t, req) + assert.Same(t, lookedUpIdentity, req.Identity) + assert.Same(t, lookedUpIdentity, cfg.lastCompositeIdentity) +} + +func TestCreateAuthorizeRoleRequestUsesSyntheticFallbackWhenLookupFails(t *testing.T) { + cfg := &authorizeRoleRequestConfigStub{ + Config: config.DefaultConfig(), + identityErr: errors.New("identity not found"), + } + provider := newAuthorizeRoleRequestProvider() + role := newAuthorizeRoleRequestRole() + + req, err := models.CreateAuthorizeRoleRequest(cfg, provider, &models.WorkflowRoleRequest{ + WorkflowID: "wf-789", + Identity: "opaque-identity-id", + Role: role, + }) + + require.NoError(t, err) + require.NotNil(t, req) + require.NotNil(t, req.Identity) + require.NotNil(t, req.Identity.User) + assert.Equal(t, "opaque-identity-id", req.Identity.ID) + assert.Equal(t, "opaque-identity-id", req.Identity.User.Email) + assert.Same(t, req.Identity, cfg.lastCompositeIdentity) +} + +func TestCreateChildWorkflowIDIncludesDeviceID(t *testing.T) { + childIDWithDevice := models.CreateChildWorkflowID("parent-wf", "authorizeRole", "test-provider", &models.WorkflowRoleRequest{ + Identity: "user@example.com", + DeviceID: "device-alpha", + }) + childIDWithoutDevice := models.CreateChildWorkflowID("parent-wf", "authorizeRole", "test-provider", &models.WorkflowRoleRequest{ + Identity: "user@example.com", + }) + + assert.NotEqual(t, childIDWithoutDevice, childIDWithDevice) +} diff --git a/internal/models/role_clone.go b/internal/models/role_clone.go new file mode 100644 index 00000000..63bc3617 --- /dev/null +++ b/internal/models/role_clone.go @@ -0,0 +1,15 @@ +package models + +func CloneRole(role *Role) *Role { + if role == nil { + return nil + } + + clone := *role + clone.Authenticators = append([]string(nil), role.Authenticators...) + clone.Workflows = append([]string(nil), role.Workflows...) + clone.Providers = append([]string(nil), role.Providers...) + clone.Inherits = append([]string(nil), role.Inherits...) + + return &clone +} diff --git a/internal/models/workflow_elevate_task.go b/internal/models/workflow_elevate_task.go index 354041f9..8b65327e 100644 --- a/internal/models/workflow_elevate_task.go +++ b/internal/models/workflow_elevate_task.go @@ -131,6 +131,28 @@ func (r *ElevateWorkflowTask) GetContextAsElevationRequest() (*ElevateRequestInt return &req, nil } +func (r *ElevateWorkflowTask) SetExecutionPlan(plan *ExecutionPlan) { + r.SetContextKeyValue(sdkConstants.VarsContextExecutionPlan, plan) +} + +func (r *ElevateWorkflowTask) GetContextAsExecutionPlan() (*ExecutionPlan, error) { + contextMap := r.GetContextAsMap() + rawPlan, ok := contextMap[sdkConstants.VarsContextExecutionPlan] + if !ok || rawPlan == nil { + return nil, fmt.Errorf("execution plan is missing from workflow context") + } + + var plan ExecutionPlan + if err := common.ConvertInterfaceToInterface(rawPlan, &plan); err != nil { + return nil, fmt.Errorf("failed to decode context as ExecutionPlan: %w", err) + } + if !plan.IsValid() { + return nil, fmt.Errorf("execution plan is missing entries") + } + + return &plan, nil +} + func (r *ElevateWorkflowTask) GetUser() *User { req, err := r.GetContextAsElevationRequest() diff --git a/internal/workflows/tasks/providers/thand/authorize.go b/internal/workflows/tasks/providers/thand/authorize.go index 08db6d38..031d3bc2 100644 --- a/internal/workflows/tasks/providers/thand/authorize.go +++ b/internal/workflows/tasks/providers/thand/authorize.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/http" + "strings" "sync" "time" @@ -96,16 +97,20 @@ func (t *thandTask) buildBasicModelOutput(elevateRequest *models.ElevateRequestI // authResult holds the result of an authorization operation type authResult struct { Identity string - AuthRequest *models.WorkflowRoleRequest + EntryID string + DeviceID string + AuthRequest *models.AuthorizeRoleRequest AuthResponse *models.AuthorizeRoleResponse Error error } // authTask represents an authorization task with all necessary context type authTask struct { - ProviderName string - Identity string - AuthRequest *models.WorkflowRoleRequest + EntryID string + ProviderName string + Identity string + DeviceID string + AuthorizeRequest *models.AuthorizeRoleRequest } // executeAuthorization performs the main authorization workflow @@ -146,54 +151,39 @@ func (t *thandTask) executeAuthorization( "revocation_at": revocationDate.Format(time.RFC3339), } - // Collect all authorization tasks - var authTasks []authTask - - if len(elevateRequest.Providers) == 0 { - return nil, fmt.Errorf("no providers specified for authorization") - } - - if len(elevateRequest.Identities) == 0 { - return nil, fmt.Errorf("no identities specified for authorization") + plan, err := t.ensureExecutionPlan(workflowTask, elevateRequest) + if err != nil { + return nil, err } - for _, providerName := range elevateRequest.Providers { - - for _, identityId := range elevateRequest.Identities { - - // Check if we have tenants specified in our request. If so, we need - // to create an authorization task for each identity and tenant combination - // if there are no tenants, we just create one task per identity - if len(elevateRequest.Tenants) == 0 { - elevateRequest.Tenants = []string{""} // Use empty string to indicate no tenant - } - - for _, tenantId := range elevateRequest.Tenants { - - authReq := models.WorkflowRoleRequest{ - WorkflowID: workflowTask.GetWorkflowID(), - Identity: identityId, - Role: elevateRequest.Role, - Duration: &duration, - Tenant: tenantId, - } - - authTasks = append(authTasks, authTask{ - ProviderName: providerName, - Identity: identityId, - AuthRequest: &authReq, - }) + var authTasks []authTask + for _, entry := range plan.Entries { + if strings.TrimSpace(entry.ProviderName) == "" { + return nil, fmt.Errorf("execution plan entry is missing provider name") + } + if entry.AuthorizeRequest == nil { + return nil, fmt.Errorf("execution plan entry for provider %q is missing authorize request", entry.ProviderName) + } + identityID := identityKeyFromAuthorizeRequest(entry.AuthorizeRequest) + if identityID == "" { + return nil, fmt.Errorf("execution plan entry for provider %q is missing identity information", entry.ProviderName) + } - log.WithFields(logrus.Fields{ - "identity": identityId, - "role": authReq.Role.GetName(), - "provider": providerName, - "duration": duration, - "tenant": tenantId, - }).Info("Preparing authorization logic") + authTasks = append(authTasks, authTask{ + EntryID: entry.EntryID, + ProviderName: entry.ProviderName, + Identity: identityID, + DeviceID: entry.DeviceID, + AuthorizeRequest: models.CloneAuthorizeRoleRequest(entry.AuthorizeRequest), + }) - } - } + log.WithFields(logrus.Fields{ + "identity": identityID, + "role": entry.AuthorizeRequest.Role.GetName(), + "provider": entry.ProviderName, + "duration": duration, + "tenant": authorizeRequestTenantID(entry.AuthorizeRequest), + }).Info("Preparing authorization logic") } var authResults []authResult @@ -230,22 +220,7 @@ func (t *thandTask) executeAuthorization( } for _, req := range authTasks { - var dur *time.Duration - if req.AuthRequest.Duration != nil { - d := *req.AuthRequest.Duration - dur = &d - } - // Create a non-composite role from the workflow's base role definition - // The role will be resolved properly by the provider if needed - requests[req.Identity] = &models.AuthorizeRoleRequest{ - Identity: &models.Identity{ID: req.AuthRequest.Identity}, - Tenant: &models.ProviderTenant{ID: req.AuthRequest.Tenant}, - Role: &models.CompositeRole{ - Role: *req.AuthRequest.Role, - Composite: false, // Explicitly set - this is a base role from workflow - }, - Duration: dur, - } + requests[req.Identity] = models.CloneAuthorizeRoleRequest(req.AuthorizeRequest) } if len(returnedErrors) > 0 && len(authorizations) == 0 { @@ -286,10 +261,36 @@ func (t *thandTask) executeAuthorization( return modelOutput, nil } +func identityKeyFromAuthorizeRequest(req *models.AuthorizeRoleRequest) string { + if req == nil || req.Identity == nil { + return "" + } + if trimmed := strings.TrimSpace(req.Identity.ID); trimmed != "" { + return trimmed + } + if req.Identity.User != nil { + if trimmed := strings.TrimSpace(req.Identity.User.Email); trimmed != "" { + return trimmed + } + if trimmed := strings.TrimSpace(req.Identity.User.Username); trimmed != "" { + return trimmed + } + } + return "" +} + +func authorizeRequestTenantID(req *models.AuthorizeRoleRequest) string { + if req == nil || req.Tenant == nil { + return "" + } + return req.Tenant.ID +} + // When a Temporal context is available, it dispatches a child workflow using -// the parent workflow's task queue (typically the agent identity), assuming -// the provider is registered on that worker. Otherwise it falls back to local -// provider execution. +// the parent workflow's task queue by default. If the request carries a +// DeviceID, it waits for a fresh live route for that device and overrides the +// child workflow routing to the device's task queue instead. Otherwise it +// falls back to local provider execution. func (t *thandTask) runAuthTask( workflowTask sdkWorkflowsModel.WorkflowTaskSupport, task authTask, @@ -302,34 +303,63 @@ func (t *thandTask) runAuthTask( wfName := models.CreateTemporalProviderWorkflowName( task.ProviderName, models.TemporalAuthorizeRoleWorkflowName) + taskQueue := workflowTask.GetTaskQueue() + childTimeout := time.Duration(0) + if strings.TrimSpace(task.DeviceID) != "" { + route, remaining, err := t.waitForFreshDeviceRoute( + ctx, + task.DeviceID, + deviceDispatchBudget(task.AuthorizeRequest), + ) + if err != nil { + return authResult{ + Identity: task.Identity, + EntryID: task.EntryID, + DeviceID: task.DeviceID, + AuthRequest: task.AuthorizeRequest, + Error: err, + } + } + taskQueue = route.TaskQueue + childTimeout = remaining + } + // Create unique child workflow ID using hash of composite identifier // (provider + role + identity + tenant) to ensure uniqueness across // different identities/tenants requesting the same role childOpts := workflow.ChildWorkflowOptions{ - WorkflowID: models.CreateChildWorkflowID( + WorkflowID: models.CreateChildWorkflowIDForEntry( workflowTask.GetWorkflowID(), "authorizeRole", - task.ProviderName, - task.AuthRequest, + task.EntryID, ), - TaskQueue: workflowTask.GetTaskQueue(), + TaskQueue: taskQueue, + } + if childTimeout > 0 { + childOpts.WorkflowExecutionTimeout = childTimeout + childOpts.WorkflowRunTimeout = childTimeout } + childOpts = childWorkflowOptionsForTaskQueue(workflowTask.GetTaskQueue(), taskQueue, childOpts) ctx = workflow.WithChildOptions(ctx, childOpts) - req := task.AuthRequest + req := models.CloneAuthorizeRoleRequest(task.AuthorizeRequest) var resp models.AuthorizeRoleResponse - err := workflow.ExecuteChildWorkflow(ctx, wfName, req).Get(ctx, &resp) + err := workflow.ExecuteChildWorkflow(ctx, wfName, *req).Get(ctx, &resp) if err != nil { return authResult{ Identity: task.Identity, - AuthRequest: task.AuthRequest, + EntryID: task.EntryID, + DeviceID: task.DeviceID, + AuthRequest: task.AuthorizeRequest, Error: err, } } return authResult{ Identity: task.Identity, - AuthRequest: task.AuthRequest, + EntryID: task.EntryID, + DeviceID: task.DeviceID, + AuthRequest: task.AuthorizeRequest, AuthResponse: &resp, Error: nil, } @@ -340,26 +370,18 @@ func (t *thandTask) runAuthTask( if err != nil { return authResult{ Identity: task.Identity, - AuthRequest: task.AuthRequest, + EntryID: task.EntryID, + DeviceID: task.DeviceID, + AuthRequest: task.AuthorizeRequest, Error: fmt.Errorf("failed to get provider: %w", err), } } - authRoleReq, err := models.CreateAuthorizeRoleRequest( - t.config, - providerCall, - task.AuthRequest, - ) - if err != nil { - return authResult{ - Identity: task.Identity, - AuthRequest: task.AuthRequest, - Error: fmt.Errorf("failed to create authorize role request: %w", err), - } - } - authOut, err := providerCall.AuthorizeRole(workflowTask.GetContext(), authRoleReq) + authOut, err := providerCall.AuthorizeRole(workflowTask.GetContext(), models.CloneAuthorizeRoleRequest(task.AuthorizeRequest)) return authResult{ Identity: task.Identity, - AuthRequest: task.AuthRequest, + EntryID: task.EntryID, + DeviceID: task.DeviceID, + AuthRequest: task.AuthorizeRequest, AuthResponse: authOut, Error: err, } diff --git a/internal/workflows/tasks/providers/thand/device_routing.go b/internal/workflows/tasks/providers/thand/device_routing.go new file mode 100644 index 00000000..337ebc4e --- /dev/null +++ b/internal/workflows/tasks/providers/thand/device_routing.go @@ -0,0 +1,109 @@ +package thand + +import ( + "errors" + "fmt" + "time" + + "github.com/thand-io/agent/internal/models" + "go.temporal.io/sdk/workflow" +) + +const ( + deviceAuthorizeMaxWait = 5 * time.Minute + deviceRouteCheckInterval = 5 * time.Second + deviceRouteCheckTimeout = 10 * time.Second + deviceRouteRevokeAttemptLimit = 1 * time.Minute + deviceRouteRevokeInitialRetry = 5 * time.Second + deviceRouteRevokeMaxRetry = 1 * time.Minute +) + +var errDeviceRouteWaitExpired = errors.New("device route wait expired") + +func deviceDispatchBudget(req *models.AuthorizeRoleRequest) time.Duration { + if req == nil || req.Duration == nil || *req.Duration <= 0 { + return deviceAuthorizeMaxWait + } + if *req.Duration < deviceAuthorizeMaxWait { + return *req.Duration + } + return deviceAuthorizeMaxWait +} + +func nextDeviceRouteRetryDelay(current time.Duration) time.Duration { + if current <= 0 { + return deviceRouteRevokeInitialRetry + } + current *= 2 + if current > deviceRouteRevokeMaxRetry { + return deviceRouteRevokeMaxRetry + } + return current +} + +func childWorkflowIDForAttempt(base string, attempt int) string { + if attempt <= 0 { + return base + } + return fmt.Sprintf("%s_retry_%d", base, attempt) +} + +func (t *thandTask) resolveFreshDeviceRoute( + ctx workflow.Context, + deviceID string, +) (*models.DeviceConnectionState, error) { + ao := workflow.LocalActivityOptions{ + StartToCloseTimeout: deviceRouteCheckTimeout, + } + actx := workflow.WithLocalActivityOptions(ctx, ao) + + var route models.DeviceConnectionState + err := workflow.ExecuteLocalActivity( + actx, + models.TemporalResolveFreshDeviceRouteActivityName, + deviceID, + ).Get(ctx, &route) + if err != nil { + return nil, err + } + + return &route, nil +} + +func (t *thandTask) waitForFreshDeviceRoute( + ctx workflow.Context, + deviceID string, + timeout time.Duration, +) (*models.DeviceConnectionState, time.Duration, error) { + if timeout <= 0 { + return nil, 0, fmt.Errorf("device route wait timeout must be positive") + } + + deadline := workflow.Now(ctx).Add(timeout) + for { + route, err := t.resolveFreshDeviceRoute(ctx, deviceID) + if err == nil { + remaining := deadline.Sub(workflow.Now(ctx)) + if remaining <= 0 { + remaining = time.Second + } + return route, remaining, nil + } + if !isDeviceRouteUnavailableError(err) { + return nil, 0, err + } + + remaining := deadline.Sub(workflow.Now(ctx)) + if remaining <= 0 { + return nil, 0, fmt.Errorf("%w: device %q did not become available within %s", errDeviceRouteWaitExpired, deviceID, timeout) + } + + sleepFor := deviceRouteCheckInterval + if sleepFor > remaining { + sleepFor = remaining + } + if err := workflow.Sleep(ctx, sleepFor); err != nil { + return nil, 0, err + } + } +} diff --git a/internal/workflows/tasks/providers/thand/device_routing_test.go b/internal/workflows/tasks/providers/thand/device_routing_test.go new file mode 100644 index 00000000..495da015 --- /dev/null +++ b/internal/workflows/tasks/providers/thand/device_routing_test.go @@ -0,0 +1,226 @@ +package thand + +import ( + "context" + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thand-io/agent/internal/models" + "github.com/thand-io/agent/internal/testing/temporaltest" + sdkWorkflowsModel "github.com/thand-io/agent/sdk/workflows/models" + "go.temporal.io/sdk/activity" + "go.temporal.io/sdk/temporal" + "go.temporal.io/sdk/testsuite" + "go.temporal.io/sdk/workflow" +) + +func newDeviceRoutingTestEnv() *testsuite.TestWorkflowEnvironment { + temporaltest.SeedBinaryChecksum() + var suite testsuite.WorkflowTestSuite + env := suite.NewTestWorkflowEnvironment() + return env +} + +func newDeviceRoutingWorkflowTask(workflowID, taskQueue string) *models.ElevateWorkflowTask { + task := models.NewElevateWorkflowTask(&sdkWorkflowsModel.WorkflowTask{ + WorkflowID: workflowID, + Context: map[string]any{}, + Input: map[string]any{}, + Output: map[string]any{}, + }) + task.SetTaskQueue(taskQueue) + return task +} + +func newDeviceRoutingAuthorizeRequest(duration time.Duration) *models.AuthorizeRoleRequest { + return &models.AuthorizeRoleRequest{ + Identity: &models.Identity{ + ID: "user@example.com", + User: &models.User{ + Email: "user@example.com", + }, + }, + Role: &models.CompositeRole{ + Role: *newExecutionPlanRole("local_sudo", "Local Sudo"), + Composite: false, + }, + Duration: &duration, + } +} + +func TestRunAuthTaskRoutesDeviceWorkflowToFreshRouteQueue(t *testing.T) { + t.Parallel() + + env := newDeviceRoutingTestEnv() + + env.RegisterActivityWithOptions( + func(context.Context, string) (models.DeviceConnectionState, error) { + return models.DeviceConnectionState{ + DeviceID: "device-alpha", + TaskQueue: "thand_local_workstation_alpha", + Connected: true, + }, nil + }, + activity.RegisterOptions{Name: models.TemporalResolveFreshDeviceRouteActivityName}, + ) + + authorizeWorkflowName := models.CreateTemporalProviderWorkflowName("local-elevation", models.TemporalAuthorizeRoleWorkflowName) + env.RegisterWorkflowWithOptions( + func(ctx workflow.Context, req models.AuthorizeRoleRequest) (*models.AuthorizeRoleResponse, error) { + assert.Equal(t, "thand_local_workstation_alpha", workflow.GetInfo(ctx).TaskQueueName) + require.NotNil(t, req.Identity) + return &models.AuthorizeRoleResponse{UserId: req.Identity.ID}, nil + }, + workflow.RegisterOptions{Name: authorizeWorkflowName}, + ) + + task := newDeviceRoutingWorkflowTask("wf-auth-route", "thand_local_server_alpha") + request := newDeviceRoutingAuthorizeRequest(30 * time.Second) + + env.ExecuteWorkflow(func(ctx workflow.Context) error { + task.WithTemporalContext(ctx) + + result := (&thandTask{}).runAuthTask(task, authTask{ + EntryID: "entry-1", + ProviderName: "local-elevation", + Identity: "user@example.com", + DeviceID: "device-alpha", + AuthorizeRequest: request, + }) + if result.Error != nil { + return result.Error + } + if result.AuthResponse == nil || result.AuthResponse.UserId != "user@example.com" { + return fmt.Errorf("authorize child workflow did not return the expected response") + } + if got, want := task.GetTaskQueue(), "thand_local_server_alpha"; got != want { + return fmt.Errorf("parent task queue = %q, want %q", got, want) + } + return nil + }) + + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) +} + +func TestRunAuthTaskFailsWhenDeviceRouteNeverAppears(t *testing.T) { + t.Parallel() + + env := newDeviceRoutingTestEnv() + + var childExecutions atomic.Int32 + env.RegisterActivityWithOptions( + func(context.Context, string) (models.DeviceConnectionState, error) { + return models.DeviceConnectionState{}, temporal.NewNonRetryableApplicationError( + "device route unavailable", + "DeviceRouteUnavailable", + nil, + ) + }, + activity.RegisterOptions{Name: models.TemporalResolveFreshDeviceRouteActivityName}, + ) + + authorizeWorkflowName := models.CreateTemporalProviderWorkflowName("local-elevation", models.TemporalAuthorizeRoleWorkflowName) + env.RegisterWorkflowWithOptions( + func(workflow.Context, models.AuthorizeRoleRequest) (*models.AuthorizeRoleResponse, error) { + childExecutions.Add(1) + return &models.AuthorizeRoleResponse{}, nil + }, + workflow.RegisterOptions{Name: authorizeWorkflowName}, + ) + + task := newDeviceRoutingWorkflowTask("wf-auth-no-route", "thand_local_server_alpha") + request := newDeviceRoutingAuthorizeRequest(time.Second) + + env.ExecuteWorkflow(func(ctx workflow.Context) error { + task.WithTemporalContext(ctx) + + result := (&thandTask{}).runAuthTask(task, authTask{ + EntryID: "entry-1", + ProviderName: "local-elevation", + Identity: "user@example.com", + DeviceID: "device-alpha", + AuthorizeRequest: request, + }) + return result.Error + }) + + require.True(t, env.IsWorkflowCompleted()) + require.Error(t, env.GetWorkflowError()) + assert.ErrorContains(t, env.GetWorkflowError(), "device route wait expired") + assert.Zero(t, childExecutions.Load(), "device-local authorize child workflow should not run without a live route") +} + +func TestRunRevokeTaskWaitsForRouteToReturnAndThenUsesDeviceQueue(t *testing.T) { + t.Parallel() + + env := newDeviceRoutingTestEnv() + + var routeLookups atomic.Int32 + env.RegisterActivityWithOptions( + func(context.Context, string) (models.DeviceConnectionState, error) { + attempt := routeLookups.Add(1) + if attempt < 3 { + return models.DeviceConnectionState{}, temporal.NewNonRetryableApplicationError( + "device route unavailable", + "DeviceRouteUnavailable", + nil, + ) + } + return models.DeviceConnectionState{ + DeviceID: "device-alpha", + TaskQueue: "thand_local_workstation_alpha", + Connected: true, + }, nil + }, + activity.RegisterOptions{Name: models.TemporalResolveFreshDeviceRouteActivityName}, + ) + + revokeWorkflowName := models.CreateTemporalProviderWorkflowName("local-elevation", models.TemporalRevokeRoleWorkflowName) + env.RegisterWorkflowWithOptions( + func(ctx workflow.Context, req models.WorkflowRevokeRoleRequest) (*models.RevokeRoleResponse, error) { + assert.Equal(t, "thand_local_workstation_alpha", workflow.GetInfo(ctx).TaskQueueName) + require.NotNil(t, req.RevokeRoleRequest) + require.NotNil(t, req.RevokeRoleRequest.AuthorizeRoleRequest) + return &models.RevokeRoleResponse{}, nil + }, + workflow.RegisterOptions{Name: revokeWorkflowName}, + ) + + task := newDeviceRoutingWorkflowTask("wf-revoke-route-retry", "thand_local_server_alpha") + duration := 30 * time.Second + + env.ExecuteWorkflow(func(ctx workflow.Context) error { + task.WithTemporalContext(ctx) + + result := (&thandTask{}).runRevokeTask(task, revokeTask{ + EntryID: "entry-1", + ProviderName: "local-elevation", + Identity: "user@example.com", + DeviceID: "device-alpha", + RevokeReq: &models.WorkflowRevokeRoleRequest{ + RevokeRoleRequest: &models.RevokeRoleRequest{ + AuthorizeRoleRequest: newDeviceRoutingAuthorizeRequest(duration), + }, + AuthorizeRoleResponse: &models.AuthorizeRoleResponse{ + UserId: "user@example.com", + }, + }, + }) + if result.Error != nil { + return result.Error + } + if got, want := task.GetTaskQueue(), "thand_local_server_alpha"; got != want { + return fmt.Errorf("parent task queue = %q, want %q", got, want) + } + return nil + }) + + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + assert.GreaterOrEqual(t, routeLookups.Load(), int32(3), "revoke should keep polling for the device route until it comes back") +} diff --git a/internal/workflows/tasks/providers/thand/errors.go b/internal/workflows/tasks/providers/thand/errors.go index 9f4d8ea4..f3783005 100644 --- a/internal/workflows/tasks/providers/thand/errors.go +++ b/internal/workflows/tasks/providers/thand/errors.go @@ -3,6 +3,7 @@ package thand import ( "errors" "fmt" + "strings" "go.temporal.io/sdk/temporal" ) @@ -39,3 +40,63 @@ func unwrapTemporalError(err error) error { return foundError } + +func isTemporalApplicationErrorType(err error, targetType string) bool { + if err == nil { + return false + } + + var activityErr *temporal.ActivityError + if errors.As(err, &activityErr) { + if innerErr := errors.Unwrap(activityErr); innerErr != nil { + err = innerErr + } + } + + var appErr *temporal.ApplicationError + return errors.As(err, &appErr) && appErr.Type() == targetType +} + +func isDeviceRouteUnavailableError(err error) bool { + return isTemporalApplicationErrorType(err, "DeviceRouteUnavailable") +} + +func isTemporalTimeoutError(err error) bool { + var timeoutErr *temporal.TimeoutError + return errors.As(err, &timeoutErr) +} + +func temporalErrorMessage(err error) string { + if err == nil { + return "" + } + + var appErr *temporal.ApplicationError + if errors.As(err, &appErr) { + return appErr.Message() + } + + return err.Error() +} + +func isTransientBrokerRevokeError(err error) bool { + message := strings.ToLower(temporalErrorMessage(err)) + if message == "" { + return false + } + + transientNeedles := []string{ + "underlying connection interrupted", + "connection interrupted", + "connection invalidated", + "session manually canceled", + } + + for _, needle := range transientNeedles { + if strings.Contains(message, needle) { + return true + } + } + + return false +} diff --git a/internal/workflows/tasks/providers/thand/execution_plan.go b/internal/workflows/tasks/providers/thand/execution_plan.go new file mode 100644 index 00000000..3e59cd21 --- /dev/null +++ b/internal/workflows/tasks/providers/thand/execution_plan.go @@ -0,0 +1,60 @@ +package thand + +import ( + agentConfig "github.com/thand-io/agent/internal/config" + "github.com/thand-io/agent/internal/models" + "go.temporal.io/sdk/workflow" +) + +const executionPlanActivityTimeout = 1 * models.DeviceRouteFreshnessTTL + +func (t *thandTask) ensureExecutionPlan( + workflowTask *models.ElevateWorkflowTask, + elevateRequest *models.ElevateRequestInternal, +) (*models.ExecutionPlan, error) { + if plan, err := workflowTask.GetContextAsExecutionPlan(); err == nil { + return plan, nil + } + + workflowID := workflowTask.GetWorkflowID() + if !workflowTask.HasTemporalContext() { + plan, err := agentConfig.BuildExecutionPlan(t.config, workflowID, elevateRequest) + if err != nil { + return nil, err + } + workflowTask.SetExecutionPlan(plan) + return plan, nil + } + + ctx := workflowTask.GetTemporalContext() + actx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ + StartToCloseTimeout: executionPlanActivityTimeout, + }) + + var plan models.ExecutionPlan + err := workflow.ExecuteActivity( + actx, + models.TemporalBuildExecutionPlanActivityName, + models.ExecutionPlanRequest{ + WorkflowID: workflowID, + ElevateRequest: elevateRequest, + }, + ).Get(ctx, &plan) + if err != nil { + return nil, err + } + + workflowTask.SetExecutionPlan(&plan) + return &plan, nil +} + +func (t *thandTask) requireExecutionPlan( + workflowTask *models.ElevateWorkflowTask, +) (*models.ExecutionPlan, error) { + plan, err := workflowTask.GetContextAsExecutionPlan() + if err != nil { + return nil, err + } + + return plan, nil +} diff --git a/internal/workflows/tasks/providers/thand/execution_plan_test.go b/internal/workflows/tasks/providers/thand/execution_plan_test.go new file mode 100644 index 00000000..840274d5 --- /dev/null +++ b/internal/workflows/tasks/providers/thand/execution_plan_test.go @@ -0,0 +1,284 @@ +package thand + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thand-io/agent/internal/common" + "github.com/thand-io/agent/internal/config" + "github.com/thand-io/agent/internal/models" + "github.com/thand-io/agent/internal/testing/temporaltest" + sdkWorkflowsModel "github.com/thand-io/agent/sdk/workflows/models" + "go.temporal.io/sdk/activity" + "go.temporal.io/sdk/testsuite" + "go.temporal.io/sdk/workflow" +) + +type executionPlanTestProvider struct { + *models.BaseProvider +} + +func (p *executionPlanTestProvider) ValidateRole( + ctx context.Context, + user *models.Identity, + role *models.Role, +) (map[string]any, error) { + return map[string]any{}, nil +} + +func newExecutionPlanTestProvider(identifier string) *executionPlanTestProvider { + caps := models.NewProviderCapabilities().WithDefaultProvisioningConfiguration() + providerCfg := models.ProviderConfig{ + Name: identifier, + Provider: identifier, + Enabled: true, + Capabilities: caps, + Config: &models.BasicConfig{}, + } + + provider := &executionPlanTestProvider{ + BaseProvider: models.NewBaseProvider(identifier, providerCfg, caps), + } + provider.SetReady() + return provider +} + +type executionPlanTestConfig struct { + *config.Config + identities map[string]*models.Identity +} + +func (s *executionPlanTestConfig) GetIdentity(byEmail string) (*models.Identity, error) { + if identity, ok := s.identities[byEmail]; ok { + return identity, nil + } + return nil, fmt.Errorf("identity %q not found", byEmail) +} + +func newExecutionPlanWorkflowTask( + t *testing.T, + workflowID string, + req models.ElevateRequestInternal, +) *models.ElevateWorkflowTask { + t.Helper() + + contextMap := map[string]any{} + require.NoError(t, common.ConvertInterfaceToInterface(req, &contextMap)) + + return models.NewElevateWorkflowTask(&sdkWorkflowsModel.WorkflowTask{ + WorkflowID: workflowID, + Context: contextMap, + Input: map[string]any{}, + Output: map[string]any{}, + }) +} + +func newExecutionPlanRole(identifier, name string) *models.Role { + return &models.Role{ + Identifier: identifier, + Name: name, + Enabled: true, + Permissions: models.RolePermissions{ + Allow: models.RoleStatements{ + { + Operations: []string{"local:test"}, + }, + }, + }, + } +} + +func newExecutionPlanTestConfig(t *testing.T) *executionPlanTestConfig { + t.Helper() + + cfg := &executionPlanTestConfig{ + Config: config.DefaultConfig(), + identities: map[string]*models.Identity{ + "user@example.com": { + ID: "user@example.com", + User: &models.User{ + Email: "user@example.com", + Username: "example-user", + }, + }, + "second@example.com": { + ID: "second@example.com", + User: &models.User{ + Email: "second@example.com", + Username: "second-user", + }, + }, + }, + } + + cfg.AddProvider("test-provider", newExecutionPlanTestProvider("test-provider")) + cfg.AddProvider("local-elevation", newExecutionPlanTestProvider("local")) + return cfg +} + +func TestBuildExecutionPlanMaterializesAuthorizeRequests(t *testing.T) { + cfg := newExecutionPlanTestConfig(t) + req := models.ElevateRequestInternal{ + ElevateRequest: models.ElevateRequest{ + Role: newExecutionPlanRole("admin", "Admin"), + Providers: []string{"test-provider"}, + Workflow: "aws_simple_elevation", + Reason: "maintenance", + Duration: "30m", + Identities: []string{"user@example.com"}, + }, + } + + plan, err := config.BuildExecutionPlan(cfg, "wf-execution-plan", &req) + require.NoError(t, err) + require.Len(t, plan.Entries, 1) + + entry := plan.Entries[0] + require.NotEmpty(t, entry.EntryID) + assert.Equal(t, "test-provider", entry.ProviderName) + assert.Empty(t, entry.DeviceID) + require.NotNil(t, entry.AuthorizeRequest) + require.NotNil(t, entry.AuthorizeRequest.Identity) + assert.Equal(t, "user@example.com", entry.AuthorizeRequest.Identity.User.Email) + assert.Equal(t, "admin", entry.AuthorizeRequest.Role.GetIdentifier()) +} + +func TestBuildExecutionPlanCarriesCanonicalDeviceID(t *testing.T) { + cfg := newExecutionPlanTestConfig(t) + cfg.Devices.Definitions = map[string]models.Device{ + "device-alpha": { + ID: "device-alpha", + Name: "Device Alpha", + Enabled: true, + }, + } + + req := models.ElevateRequestInternal{ + ElevateRequest: models.ElevateRequest{ + Role: newExecutionPlanRole("admin", "Admin"), + Providers: []string{"test-provider"}, + Workflow: "aws_simple_elevation", + Device: "device-alpha", + Reason: "maintenance", + Duration: "30m", + Identities: []string{"user@example.com", "second@example.com"}, + }, + } + + plan, err := config.BuildExecutionPlan(cfg, "wf-local-sudo", &req) + require.NoError(t, err) + require.Len(t, plan.Entries, 2) + + assert.Equal(t, "device-alpha", plan.Entries[0].DeviceID) + assert.Equal(t, "device-alpha", plan.Entries[1].DeviceID) + assert.NotEqual(t, plan.Entries[0].EntryID, plan.Entries[1].EntryID) + assert.NotNil(t, plan.Entries[0].AuthorizeRequest) + assert.NotNil(t, plan.Entries[1].AuthorizeRequest) +} + +func TestEnsureExecutionPlanTemporalBuildsOnceAndCachesPlan(t *testing.T) { + // TestWorkflowEnvironment does not thread WorkerOptions.BuildID through the + // lazy activity-worker path, so seed the SDK's process-wide checksum cache + // before the first ExecuteActivity to keep binary hashing out of the + // workflow deadlock-detector critical path. + temporaltest.SeedBinaryChecksum() + + var suite testsuite.WorkflowTestSuite + env := suite.NewTestWorkflowEnvironment() + + request := models.ElevateRequestInternal{ + ElevateRequest: models.ElevateRequest{ + Role: newExecutionPlanRole("admin", "Admin"), + Providers: []string{"test-provider"}, + Workflow: "aws_simple_elevation", + Reason: "maintenance", + Duration: "30m", + Identities: []string{"user@example.com"}, + }, + } + + var activityCalls int + env.RegisterActivityWithOptions( + func(context.Context, models.ExecutionPlanRequest) (*models.ExecutionPlan, error) { + activityCalls++ + return &models.ExecutionPlan{ + WorkflowName: request.GetWorkflow(), + Entries: []models.ExecutionPlanEntry{{ + EntryID: "entry-1", + ProviderName: "test-provider", + AuthorizeRequest: &models.AuthorizeRoleRequest{ + Identity: &models.Identity{ + ID: "user@example.com", + User: &models.User{Email: "user@example.com"}, + }, + Role: &models.CompositeRole{ + Role: *newExecutionPlanRole("admin", "Admin"), + Composite: false, + }, + Duration: func() *time.Duration { + d := 30 * time.Minute + return &d + }(), + }, + }}, + }, nil + }, + activity.RegisterOptions{Name: models.TemporalBuildExecutionPlanActivityName}, + ) + + // Prebuild the workflow task outside the workflow closure so this test is + // measuring execution-plan caching, not request serialization before the + // first Temporal yield. + task := newExecutionPlanWorkflowTask(t, "wf-temporal-plan", request) + + env.ExecuteWorkflow(func(ctx workflow.Context) error { + task.WithTemporalContext(ctx) + + runner := &thandTask{} + firstPlan, err := runner.ensureExecutionPlan(task, &request) + if err != nil { + return err + } + secondPlan, err := runner.ensureExecutionPlan(task, &request) + if err != nil { + return err + } + if firstPlan == nil || secondPlan == nil { + return fmt.Errorf("execution plan was not returned") + } + if firstPlan.Entries[0].EntryID != secondPlan.Entries[0].EntryID { + return fmt.Errorf("execution plan was not reused from workflow context") + } + return nil + }) + + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + assert.Equal(t, 1, activityCalls) +} + +func TestExecuteRevocationTaskRequiresRecordedExecutionPlan(t *testing.T) { + cfg := newExecutionPlanTestConfig(t) + req := models.ElevateRequestInternal{ + ElevateRequest: models.ElevateRequest{ + Role: newExecutionPlanRole("admin", "Admin"), + Providers: []string{"test-provider"}, + Workflow: "aws_simple_elevation", + Reason: "maintenance", + Duration: "30m", + Identities: []string{"user@example.com"}, + }, + } + + task := newExecutionPlanWorkflowTask(t, "wf-revoke-without-plan", req) + runner := &thandTask{config: cfg} + + output, err := runner.executeRevocationTask(task, "revoke", nil, &req, &RevokeTask{}) + require.Error(t, err) + assert.Nil(t, output) + assert.ErrorContains(t, err, "execution plan is missing") +} diff --git a/internal/workflows/tasks/providers/thand/main.go b/internal/workflows/tasks/providers/thand/main.go index f7455caa..3d73a7ce 100644 --- a/internal/workflows/tasks/providers/thand/main.go +++ b/internal/workflows/tasks/providers/thand/main.go @@ -79,6 +79,14 @@ func (t *thandTask) resolveIdentity(identity string) *models.Identity { } } +func (t *thandTask) resolveIdentitySnapshot(identity string) *models.Identity { + identityResult, err := t.config.GetIdentity(identity) + if err != nil || identityResult == nil || identityResult.User == nil { + return nil + } + return identityResult +} + func (f *thandTask) GetVersion() string { return "1.0.0" } diff --git a/internal/workflows/tasks/providers/thand/revoke.go b/internal/workflows/tasks/providers/thand/revoke.go index c53a8081..ce402d6c 100644 --- a/internal/workflows/tasks/providers/thand/revoke.go +++ b/internal/workflows/tasks/providers/thand/revoke.go @@ -3,6 +3,7 @@ package thand import ( "errors" "fmt" + "strings" "sync" "time" @@ -65,10 +66,47 @@ type revokeResult struct { // revokeTask represents a revocation task with all necessary context type revokeTask struct { - ProviderName string - Identity string - RevokeReq *models.WorkflowRevokeRoleRequest - AuthorizeResponse *models.AuthorizeRoleResponse + EntryID string + ProviderName string + Identity string + DeviceID string + RevokeReq *models.WorkflowRevokeRoleRequest +} + +func hydrateAuthorizeResponse( + workflowTask *models.ElevateWorkflowTask, + identityID string, + log *sdkWorkflowsModel.LogBuilder, +) *models.AuthorizeRoleResponse { + req := workflowTask.GetContextAsMap() + if req == nil { + return nil + } + + authorizationsMap, ok := req["authorizations"] + if !ok { + log.WithField("identity", identityID).Debug("No authorizations found in context for revocation") + return nil + } + + if objectMap, ok := authorizationsMap.(map[string]any); ok { + if identityMap, ok := objectMap[identityID].(map[string]any); ok { + localResponse := models.AuthorizeRoleResponse{} + if err := common.ConvertMapToInterface(identityMap, &localResponse); err != nil { + log.WithError(err).WithField("identity", identityID).Warn("Failed to convert authorize response") + return nil + } + return &localResponse + } + } + + if authzMap, ok := authorizationsMap.(map[string]*models.AuthorizeRoleResponse); ok { + if authResp, ok := authzMap[identityID]; ok { + return authResp + } + } + + return nil } func (t *thandTask) executeRevocationTask( @@ -85,11 +123,6 @@ func (t *thandTask) executeRevocationTask( log := workflowTask.GetLogger() - duration, err := elevateRequest.AsDuration() - if err != nil { - return nil, fmt.Errorf("failed to get duration: %w", err) - } - revokedAt := time.Now().UTC() modelOutput := map[string]any{ @@ -97,78 +130,49 @@ func (t *thandTask) executeRevocationTask( "revoked_at": revokedAt.Format(time.RFC3339), } - // Collect all revocation tasks - var revokeTasks []revokeTask - - for _, providerName := range elevateRequest.Providers { - - for _, identityId := range elevateRequest.Identities { - - var authorizeResponse *models.AuthorizeRoleResponse - - // Try to hydrate the authorization response for this identity - req := workflowTask.GetContextAsMap() - if req != nil { - - authorizationsMap, ok := req["authorizations"] - - if !ok { - log.WithField("identity", identityId).Debug("No authorizations found in context for revocation") - continue - } - - if objectMap, ok := authorizationsMap.(map[string]any); ok { - if identityMap, ok := objectMap[identityId].(map[string]any); ok { - localResponse := models.AuthorizeRoleResponse{} - if err := common.ConvertMapToInterface(identityMap, &localResponse); err != nil { - log.WithError(err).WithField("identity", identityId).Warn("Failed to convert authorize response") - } - authorizeResponse = &localResponse - } - } else if authzMap, ok := authorizationsMap.(map[string]*models.AuthorizeRoleResponse); ok { - if authResp, ok := authzMap[identityId]; ok { - authorizeResponse = authResp - } - } - } - - // Check if we have tenants specified in our request. If so, we need - // to create a revocation task for each identity and tenant combination - // If there are no tenants, we just create one task per identity - tenantsToProcess := elevateRequest.Tenants - if len(tenantsToProcess) == 0 { - tenantsToProcess = []string{""} // Use empty string to indicate no tenant - } - - for _, tenantID := range tenantsToProcess { - - revokeReq := models.WorkflowRevokeRoleRequest{ - RevokeRoleRequest: &models.WorkflowRoleRequest{ - WorkflowID: workflowTask.GetWorkflowID(), - Identity: identityId, - Role: elevateRequest.Role, - Duration: &duration, - Tenant: tenantID, - }, - AuthorizeRoleResponse: authorizeResponse, - } + plan, err := t.requireExecutionPlan(workflowTask) + if err != nil { + return nil, err + } - revokeTasks = append(revokeTasks, revokeTask{ - ProviderName: providerName, - Identity: identityId, - RevokeReq: &revokeReq, - AuthorizeResponse: authorizeResponse, - }) + var revokeTasks []revokeTask + for _, entry := range plan.Entries { + if strings.TrimSpace(entry.ProviderName) == "" { + return nil, fmt.Errorf("execution plan entry is missing provider name") + } + if entry.AuthorizeRequest == nil { + return nil, fmt.Errorf("execution plan entry for provider %q is missing authorize request", entry.ProviderName) + } - log.WithFields(logrus.Fields{ - "user": identityId, - "role": elevateRequest.Role.GetName(), - "provider": providerName, - "duration": duration, - "tenant": tenantID, - }).Info("Preparing revocation logic") - } + identityID := identityKeyFromAuthorizeRequest(entry.AuthorizeRequest) + if identityID == "" { + return nil, fmt.Errorf("execution plan entry for provider %q is missing identity information", entry.ProviderName) } + authorizeResponse := hydrateAuthorizeResponse(workflowTask, identityID, log) + + revokeReq := models.WorkflowRevokeRoleRequest{ + RevokeRoleRequest: &models.RevokeRoleRequest{ + AuthorizeRoleRequest: models.CloneAuthorizeRoleRequest(entry.AuthorizeRequest), + AuthorizeRoleResponse: authorizeResponse, + }, + AuthorizeRoleResponse: authorizeResponse, + } + + revokeTasks = append(revokeTasks, revokeTask{ + EntryID: entry.EntryID, + ProviderName: entry.ProviderName, + Identity: identityID, + DeviceID: entry.DeviceID, + RevokeReq: &revokeReq, + }) + + log.WithFields(logrus.Fields{ + "user": identityID, + "role": entry.AuthorizeRequest.Role.GetName(), + "provider": entry.ProviderName, + "duration": entry.AuthorizeRequest.Duration, + "tenant": authorizeRequestTenantID(entry.AuthorizeRequest), + }).Info("Preparing revocation logic") } var revokeResults []revokeResult @@ -227,9 +231,11 @@ func (t *thandTask) executeRevocationTask( } // runRevokeTask executes a single revocation task and returns its result. -// When a Temporal context is available, it dispatches a child workflow to the -// agent that has the provider registered (routed via TaskQueue = ProviderName). -// Otherwise it falls back to local provider execution. +// When a Temporal context is available, it dispatches a child workflow using +// the parent workflow's task queue by default. If the request carries a +// DeviceID, it waits for a fresh live route for that device and retries until +// it can reconcile revocation on the device's task queue instead. Otherwise it +// falls back to local provider execution. func (t *thandTask) runRevokeTask( workflowTask sdkWorkflowsModel.WorkflowTaskSupport, task revokeTask, @@ -242,34 +248,83 @@ func (t *thandTask) runRevokeTask( wfName := models.CreateTemporalProviderWorkflowName( task.ProviderName, models.TemporalRevokeRoleWorkflowName) - // Create unique child workflow ID using hash of composite identifier - // (provider + role + identity + tenant) to ensure uniqueness across - // different identities/tenants requesting the same role - childOpts := workflow.ChildWorkflowOptions{ - WorkflowID: models.CreateChildWorkflowID( - workflowTask.GetWorkflowID(), - "revokeRole", - task.ProviderName, - task.RevokeReq.RevokeRoleRequest, - ), - TaskQueue: workflowTask.GetTaskQueue(), - } - ctx = workflow.WithChildOptions(ctx, childOpts) + taskQueue := workflowTask.GetTaskQueue() + deviceID := "" + deviceID = strings.TrimSpace(task.DeviceID) - req := task.RevokeReq + baseWorkflowID := models.CreateChildWorkflowIDForEntry( + workflowTask.GetWorkflowID(), + "revokeRole", + task.EntryID, + ) - var resp models.RevokeRoleResponse - err := workflow.ExecuteChildWorkflow(ctx, wfName, req).Get(ctx, &resp) - if err != nil { - return revokeResult{ - Identity: task.Identity, - Error: err, + retryDelay := deviceRouteRevokeInitialRetry + attempt := 0 + for { + if deviceID != "" { + route, _, err := t.waitForFreshDeviceRoute( + ctx, + deviceID, + deviceRouteRevokeAttemptLimit, + ) + if err != nil { + if isDeviceRouteUnavailableError(err) || errors.Is(err, errDeviceRouteWaitExpired) { + if sleepErr := workflow.Sleep(ctx, retryDelay); sleepErr != nil { + return revokeResult{ + Identity: task.Identity, + Error: sleepErr, + } + } + retryDelay = nextDeviceRouteRetryDelay(retryDelay) + attempt++ + continue + } + return revokeResult{ + Identity: task.Identity, + Error: err, + } + } + taskQueue = route.TaskQueue } - } - return revokeResult{ - Identity: task.Identity, - Output: &resp, - Error: nil, + + childOpts := workflow.ChildWorkflowOptions{ + WorkflowID: childWorkflowIDForAttempt(baseWorkflowID, attempt), + TaskQueue: taskQueue, + WorkflowExecutionTimeout: deviceRouteRevokeAttemptLimit, + WorkflowRunTimeout: deviceRouteRevokeAttemptLimit, + } + childOpts = childWorkflowOptionsForTaskQueue(workflowTask.GetTaskQueue(), taskQueue, childOpts) + childCtx := workflow.WithChildOptions(ctx, childOpts) + + req := models.WorkflowRevokeRoleRequest{ + RevokeRoleRequest: models.CloneRevokeRoleRequest(task.RevokeReq.RevokeRoleRequest), + AuthorizeRoleResponse: task.RevokeReq.AuthorizeRoleResponse, + } + + var resp models.RevokeRoleResponse + err := workflow.ExecuteChildWorkflow(childCtx, wfName, req).Get(childCtx, &resp) + if err == nil { + return revokeResult{ + Identity: task.Identity, + Output: &resp, + Error: nil, + } + } + if !isTemporalTimeoutError(err) && !isTransientBrokerRevokeError(err) { + return revokeResult{ + Identity: task.Identity, + Error: err, + } + } + + if sleepErr := workflow.Sleep(ctx, retryDelay); sleepErr != nil { + return revokeResult{ + Identity: task.Identity, + Error: sleepErr, + } + } + retryDelay = nextDeviceRouteRetryDelay(retryDelay) + attempt++ } } @@ -281,21 +336,7 @@ func (t *thandTask) runRevokeTask( Error: fmt.Errorf("failed to get provider: %w", err), } } - authRoleReq, err := models.CreateAuthorizeRoleRequest( - t.config, - providerCall, - task.RevokeReq.RevokeRoleRequest, - ) - if err != nil { - return revokeResult{ - Identity: task.Identity, - Error: fmt.Errorf("failed to create authorize role request: %w", err), - } - } - revokeOut, err := providerCall.RevokeRole(workflowTask.GetContext(), &models.RevokeRoleRequest{ - AuthorizeRoleRequest: authRoleReq, - AuthorizeRoleResponse: task.AuthorizeResponse, - }) + revokeOut, err := providerCall.RevokeRole(workflowTask.GetContext(), models.CloneRevokeRoleRequest(task.RevokeReq.RevokeRoleRequest)) return revokeResult{ Identity: task.Identity, Output: revokeOut, diff --git a/sdk/constants/workflow_elevate_task.go b/sdk/constants/workflow_elevate_task.go index e0afb63b..04bcab2f 100644 --- a/sdk/constants/workflow_elevate_task.go +++ b/sdk/constants/workflow_elevate_task.go @@ -1,10 +1,11 @@ package constants const ( - VarsContextUser = "user" - VarsContextRequest = "request" - VarsContextProviders = "providers" - VarsContextWorkflow = "workflow" - VarsContextRole = "role" - VarsContextApproved = "approved" + VarsContextUser = "user" + VarsContextRequest = "request" + VarsContextProviders = "providers" + VarsContextWorkflow = "workflow" + VarsContextRole = "role" + VarsContextApproved = "approved" + VarsContextExecutionPlan = "execution_plan" ) diff --git a/sdk/workflows/models/workflow_task.go b/sdk/workflows/models/workflow_task.go index cdf6f6bc..5d411dd1 100644 --- a/sdk/workflows/models/workflow_task.go +++ b/sdk/workflows/models/workflow_task.go @@ -31,12 +31,13 @@ import ( type ctxKey string const ( - VarsContextUser = "user" - VarsContextRequest = "request" - VarsContextProviders = "providers" - VarsContextWorkflow = "workflow" - VarsContextRole = "role" - VarsContextApproved = "approved" + VarsContextUser = "user" + VarsContextRequest = "request" + VarsContextProviders = "providers" + VarsContextWorkflow = "workflow" + VarsContextRole = "role" + VarsContextApproved = "approved" + VarsContextExecutionPlan = "execution_plan" runnerCtxKey ctxKey = "wfRunnerContext" temporalCtxKey ctxKey = "wfTemporalContext" From 0f55f82d931e0024eaae70cef5b83c1a2f961b7d Mon Sep 17 00:00:00 2001 From: Michael Weber Date: Thu, 23 Apr 2026 21:17:47 -0500 Subject: [PATCH 08/23] feat(local-sudo): add device-local sudo request and provider flow --- cmd/cli/sudo.go | 74 ++ cmd/cli/sudo_test.go | 237 +++++ docs/api/agent/elevation.md | 4 + docs/api/agent/local-sudo.md | 89 ++ docs/configuration/index.md | 1 + docs/configuration/local-sudo.md | 136 +++ docs/getting-started.md | 5 + internal/api/elevate.go | 20 + internal/api/elevate_test.go | 21 + .../config/device_local_elevation_test.go | 57 + internal/config/environment/local/roles.yaml | 18 +- .../config/execution_plan_activity_test.go | 137 +++ internal/config/local_sudo_execution_plan.go | 104 +- internal/config/providers_local.go | 5 + internal/daemon/elevate.go | 10 +- internal/daemon/elevate_prefill_test.go | 31 + internal/daemon/static/elevate_static.html | 15 + internal/models/device.go | 13 +- internal/models/local_request.go | 227 ++++ internal/models/local_request_test.go | 79 ++ internal/providers/local/capabilities.go | 13 + internal/providers/local/main.go | 977 ++++++++++++++++++ internal/providers/local/main_test.go | 408 ++++++++ internal/providers/local/schema.go | 31 + .../providers/thand/execution_plan_test.go | 34 +- sdk/workflows/models/workflow_ctx.go | 39 + sdk/workflows/models/workflow_ctx_test.go | 61 ++ 27 files changed, 2820 insertions(+), 26 deletions(-) create mode 100644 cmd/cli/sudo.go create mode 100644 cmd/cli/sudo_test.go create mode 100644 docs/api/agent/local-sudo.md create mode 100644 docs/configuration/local-sudo.md create mode 100644 internal/config/device_local_elevation_test.go create mode 100644 internal/config/execution_plan_activity_test.go create mode 100644 internal/config/providers_local.go create mode 100644 internal/daemon/elevate_prefill_test.go create mode 100644 internal/models/local_request.go create mode 100644 internal/models/local_request_test.go create mode 100644 internal/providers/local/capabilities.go create mode 100644 internal/providers/local/main.go create mode 100644 internal/providers/local/main_test.go create mode 100644 internal/providers/local/schema.go create mode 100644 sdk/workflows/models/workflow_ctx_test.go diff --git a/cmd/cli/sudo.go b/cmd/cli/sudo.go new file mode 100644 index 00000000..62b21aac --- /dev/null +++ b/cmd/cli/sudo.go @@ -0,0 +1,74 @@ +package cli + +import ( + "fmt" + + "github.com/spf13/cobra" + "github.com/thand-io/agent/internal/common" + "github.com/thand-io/agent/internal/models" +) + +var sudoCmd = &cobra.Command{ + Use: "sudo [command...]", + Short: "Request local sudo access or run a privileged command", + Long: `Request time-bound local sudo access or run a single privileged command through the local provider.`, + PreRunE: preRunClientConfigWithSessionE, + RunE: func(cmd *cobra.Command, args []string) error { + reason, _ := cmd.Flags().GetString("reason") + duration, _ := cmd.Flags().GetString("duration") + device, _ := cmd.Flags().GetString("device") + if !cmd.Flags().Changed("device") { + device = common.GetDeviceID().String() + } + request, err := buildLocalSudoElevationRequest(args, reason, duration, device) + if err != nil { + return err + } + + return MakeElevationRequest(request) + }, +} + +func buildLocalSudoElevationRequest(args []string, reason, duration, device string) (*models.ElevateRequest, error) { + if len(reason) == 0 { + return nil, fmt.Errorf("--reason is required") + } + if cfg == nil { + return nil, fmt.Errorf("configuration is not loaded") + } + + metadata := models.LocalSudoRequestMetadata{ + Mode: models.LocalSudoModeTimed, + } + + if len(args) > 0 { + metadata.Mode = models.LocalSudoModeCommand + metadata.Command = append([]string(nil), args...) + } + + role, err := cfg.GetRoleByName(models.LocalSudoRoleIdentifier) + if err != nil { + return nil, fmt.Errorf("local sudo role %q is not configured: %w", models.LocalSudoRoleIdentifier, err) + } + + request := &models.ElevateRequest{ + Role: models.CloneRole(role), + Device: device, + Reason: reason, + Duration: duration, + Metadata: metadata.AsMap(), + } + if err := models.NormalizeLocalSudoRequest(request, cfg.GetProviders().Definitions); err != nil { + return nil, err + } + + return request, nil +} + +func init() { + requestCmd.AddCommand(sudoCmd) + + sudoCmd.Flags().StringP("duration", "d", "", "Duration of timed sudo access (for example 30m or 1h)") + sudoCmd.Flags().StringP("reason", "e", "", "Reason for the sudo request") + sudoCmd.Flags().String("device", "", "Canonical device_id for local sudo execution (defaults to the current device when omitted)") +} diff --git a/cmd/cli/sudo_test.go b/cmd/cli/sudo_test.go new file mode 100644 index 00000000..942ebf68 --- /dev/null +++ b/cmd/cli/sudo_test.go @@ -0,0 +1,237 @@ +package cli + +import ( + "testing" + + "github.com/spf13/cobra" + "github.com/thand-io/agent/internal/common" + configpkg "github.com/thand-io/agent/internal/config" + "github.com/thand-io/agent/internal/models" +) + +func TestBuildLocalSudoElevationRequestTimed(t *testing.T) { + previousCfg := cfg + t.Cleanup(func() { cfg = previousCfg }) + + cfg = newTestSudoConfig("local-custom") + + request, err := buildLocalSudoElevationRequest(nil, "system maintenance", "30m", "device-alpha") + if err != nil { + t.Fatalf("buildLocalSudoElevationRequest returned error: %v", err) + } + + if got, want := request.Workflow, models.LocalSudoTimedWorkflowName; got != want { + t.Fatalf("workflow = %q, want %q", got, want) + } + if got, want := request.Duration, "30m"; got != want { + t.Fatalf("duration = %q, want %q", got, want) + } + if got, want := request.Providers[0], "local-custom"; got != want { + t.Fatalf("provider = %q, want %q", got, want) + } + if request.Metadata["mode"] != string(models.LocalSudoModeTimed) { + t.Fatalf("mode = %#v, want %q", request.Metadata["mode"], models.LocalSudoModeTimed) + } + if got, want := request.Device, "device-alpha"; got != want { + t.Fatalf("device = %#v, want %q", got, want) + } + if got, want := request.Metadata["device_id"], "device-alpha"; got != want { + t.Fatalf("device_id = %#v, want %q", got, want) + } + if !containsString(request.Role.Providers, "local-custom") { + t.Fatalf("request role providers = %#v, want provider alias included", request.Role.Providers) + } +} + +func TestBuildLocalSudoElevationRequestCommandUsesDefaultDuration(t *testing.T) { + previousCfg := cfg + t.Cleanup(func() { cfg = previousCfg }) + + cfg = newTestSudoConfig("local-elevation") + + request, err := buildLocalSudoElevationRequest([]string{"whoami"}, "check user", "", "device-beta") + if err != nil { + t.Fatalf("buildLocalSudoElevationRequest returned error: %v", err) + } + + if got, want := request.Workflow, models.LocalSudoCommandWorkflowName; got != want { + t.Fatalf("workflow = %q, want %q", got, want) + } + if got, want := request.Duration, models.LocalSudoCommandDuration; got != want { + t.Fatalf("duration = %q, want %q", got, want) + } + command, ok := request.Metadata["command"].([]string) + if !ok { + t.Fatalf("metadata command type = %T, want []string", request.Metadata["command"]) + } + if len(command) != 1 || command[0] != "whoami" { + t.Fatalf("metadata command = %#v, want [\"whoami\"]", command) + } + if got, want := request.Metadata["device_id"], "device-beta"; got != want { + t.Fatalf("device_id = %#v, want %q", got, want) + } +} + +func TestBuildLocalSudoElevationRequestRequiresTimedDuration(t *testing.T) { + previousCfg := cfg + t.Cleanup(func() { cfg = previousCfg }) + + cfg = newTestSudoConfig("local-elevation") + + if _, err := buildLocalSudoElevationRequest(nil, "missing duration", "", "device-alpha"); err == nil { + t.Fatal("expected error for missing timed duration") + } +} + +func TestBuildLocalSudoElevationRequestPrefersLocalElevationProvider(t *testing.T) { + previousCfg := cfg + t.Cleanup(func() { cfg = previousCfg }) + + cfg = &configpkg.Config{ + Providers: configpkg.ProviderDefinitionsConfig{ + Definitions: map[string]models.ProviderConfig{ + "local": { + Name: "Local", + Provider: "local", + Enabled: true, + }, + "local-elevation": { + Name: "Local Elevation", + Provider: "local", + Enabled: true, + }, + }, + }, + Roles: configpkg.RoleConfig{ + Definitions: map[string]models.Role{ + models.LocalSudoRoleIdentifier: { + Name: "Local Sudo", + Identifier: models.LocalSudoRoleIdentifier, + Providers: []string{"local", "local-elevation"}, + Workflows: []string{models.LocalSudoTimedWorkflowName, models.LocalSudoCommandWorkflowName}, + Permissions: models.RolePermissions{ + Allow: models.RoleStatements{ + {Operations: []string{"local:sudo:*"}}, + }, + }, + }, + }, + }, + } + + request, err := buildLocalSudoElevationRequest(nil, "system maintenance", "30m", "device-alpha") + if err != nil { + t.Fatalf("buildLocalSudoElevationRequest returned error: %v", err) + } + + if got, want := request.Providers[0], "local-elevation"; got != want { + t.Fatalf("provider = %q, want %q", got, want) + } +} + +func newTestSudoConfig(providerName string) *configpkg.Config { + return &configpkg.Config{ + Providers: configpkg.ProviderDefinitionsConfig{ + Definitions: map[string]models.ProviderConfig{ + providerName: { + Name: "Local", + Provider: "local", + Enabled: true, + }, + }, + }, + Roles: configpkg.RoleConfig{ + Definitions: map[string]models.Role{ + models.LocalSudoRoleIdentifier: { + Name: "Local Sudo", + Identifier: models.LocalSudoRoleIdentifier, + Providers: []string{"local", "local-elevation"}, + Workflows: []string{models.LocalSudoTimedWorkflowName, models.LocalSudoCommandWorkflowName}, + Permissions: models.RolePermissions{ + Allow: models.RoleStatements{ + {Operations: []string{"local:sudo:*"}}, + }, + }, + }, + }, + }, + } +} + +func containsString(values []string, target string) bool { + for _, value := range values { + if value == target { + return true + } + } + return false +} + +func TestBuildLocalSudoElevationRequestRequiresConfiguredEnvironment(t *testing.T) { + previousCfg := cfg + t.Cleanup(func() { cfg = previousCfg }) + + cfg = nil + + if _, err := buildLocalSudoElevationRequest(nil, "system maintenance", "30m", ""); err == nil { + t.Fatal("expected error when config is unavailable") + } +} + +func TestSudoCommandDefaultsDeviceToCurrentMachineWhenFlagOmitted(t *testing.T) { + previousCfg := cfg + t.Cleanup(func() { cfg = previousCfg }) + + cfg = newTestSudoConfig("local-elevation") + + cmd := &cobra.Command{Use: "sudo"} + cmd.Flags().String("device", "", "") + + device, err := cmd.Flags().GetString("device") + if err != nil { + t.Fatalf("GetString(device) returned error: %v", err) + } + if cmd.Flags().Changed("device") { + t.Fatal("device flag should not be marked changed when omitted") + } + if !cmd.Flags().Changed("device") { + device = common.GetDeviceID().String() + } + + request, err := buildLocalSudoElevationRequest(nil, "system maintenance", "30m", device) + if err != nil { + t.Fatalf("buildLocalSudoElevationRequest returned error: %v", err) + } + + if got, want := request.Device, common.GetDeviceID().String(); got != want { + t.Fatalf("device = %q, want %q", got, want) + } + if got, want := request.Metadata["device_id"], common.GetDeviceID().String(); got != want { + t.Fatalf("metadata device_id = %#v, want %q", got, want) + } +} + +func TestSudoCommandPreservesExplicitEmptyDeviceFlag(t *testing.T) { + previousCfg := cfg + t.Cleanup(func() { cfg = previousCfg }) + + cfg = newTestSudoConfig("local-elevation") + + cmd := &cobra.Command{Use: "sudo"} + cmd.Flags().String("device", "", "") + if err := cmd.Flags().Set("device", ""); err != nil { + t.Fatalf("Set(device) returned error: %v", err) + } + + device, err := cmd.Flags().GetString("device") + if err != nil { + t.Fatalf("GetString(device) returned error: %v", err) + } + if !cmd.Flags().Changed("device") { + t.Fatal("device flag should be marked changed when explicitly set") + } + + if _, err := buildLocalSudoElevationRequest(nil, "system maintenance", "30m", device); err == nil { + t.Fatal("expected explicit empty device flag to remain empty and fail validation") + } +} diff --git a/docs/api/agent/elevation.md b/docs/api/agent/elevation.md index f6e4cc92..8f698e6d 100644 --- a/docs/api/agent/elevation.md +++ b/docs/api/agent/elevation.md @@ -187,3 +187,7 @@ Raw encrypted workflow state or task token for resuming workflows. - Used internally by workflow engine to resume paused workflows - State parameter contains encrypted workflow context - Supports both query parameter and body-based resumption + +## Related Guides + +- [Local Sudo Usage](local-sudo.md) diff --git a/docs/api/agent/local-sudo.md b/docs/api/agent/local-sudo.md new file mode 100644 index 00000000..e8ebc29a --- /dev/null +++ b/docs/api/agent/local-sudo.md @@ -0,0 +1,89 @@ +--- +layout: default +title: Local Sudo +parent: Agent +grand_parent: API Reference +nav_order: 5 +--- + +# Local Sudo + +Local sudo lets thand request short-lived privileged access on a specific registered device. + +Use this feature when you want a normal thand approval workflow to grant temporary local administrative access on a machine that is running a thand agent. + +## Request Types + +Local sudo supports two modes: + +- timed access, which grants sudo for a bounded lease +- command mode, which runs a specific command and cleans up immediately afterward + +## Requesting Timed Local Sudo + +CLI example: + +```bash +thand request sudo --device 11111111-2222-3333-4444-555555555555 --duration 30m --reason "System maintenance" +``` + +If `--device` is omitted, the CLI defaults to the current machine's canonical `device_id`. +If `--device` is provided explicitly, the CLI uses that exact value, even if it is empty. + +Static web example: + +```text +/api/v1/elevate?role=local_sudo&device=11111111-2222-3333-4444-555555555555&duration=PT30M&reason=System+maintenance +``` + +## Requesting Command Mode + +CLI example: + +```bash +thand request sudo --device 11111111-2222-3333-4444-555555555555 --command softwareupdate --command -i --command -a --reason "Install updates" +``` + +Command mode defaults to a short duration window and removes the local grant immediately after the command finishes. + +## Device Availability + +If the target device is offline, local sudo does not fail immediately. + +Authorize waits for a fresh device route for a bounded window: + +- up to the requested sudo duration +- capped at 5 minutes + +If the device does not reconnect in that window, the request fails instead of succeeding unexpectedly later. + +## Workflow Behavior + +Local sudo resolves device-local execution details such as the target device and local account mapping as part of the internal execution-planning work that runs at the start of `authorize`. + +That planning step reads shared device policy from the Temporal-backed device-definition registry rather than depending on the handling server having the target device configured locally. + +If you are authoring or reviewing workflows, see [Workflow Tasks](/configuration/workflows/tasks.html) for the `authorize` lifecycle and execution-planning behavior. + +## Revoke Behavior + +Timed revoke is reconciliation-oriented. + +- if the device is online, revoke is dispatched immediately +- if the device is offline, revoke remains pending until the device reconnects and the server can reconcile state + +Timed access is still expected to expire locally on the device based on the local lease. The pending revoke exists so the workflow can converge and leave an accurate audit trail. + +## Copy / Resume URLs + +The static request page preserves the `device` field in copied request URLs so the target device stays attached when the request is reopened later. + +The `device` value is the canonical `device_id`. Operators can print the current machine's device ID with `thand config device-id`. + +Live device routing is also keyed by that same `device_id`. + +## Related Docs + +- [Local Sudo Configuration](/configuration/local-sudo.html) +- [Elevation (Access Request) Endpoints](/api/agent/elevation.html) +- [Workflow Tasks](/configuration/workflows/tasks.html) diff --git a/docs/configuration/index.md b/docs/configuration/index.md index bcc37590..600533e5 100644 --- a/docs/configuration/index.md +++ b/docs/configuration/index.md @@ -85,3 +85,4 @@ export THAND_PROVIDERS_AWS_REGION="us-west-2" - **[Providers](providers)** - Provider configurations - **[Roles](roles)** - Role definitions and mappings - **[Workflows](workflows)** - Custom approval workflows +- **[Local Sudo](local-sudo)** - Device-local sudo configuration diff --git a/docs/configuration/local-sudo.md b/docs/configuration/local-sudo.md new file mode 100644 index 00000000..5ca3caa3 --- /dev/null +++ b/docs/configuration/local-sudo.md @@ -0,0 +1,136 @@ +--- +layout: default +title: Local Sudo +parent: Configuration +nav_order: 8 +--- + +# Local Sudo Configuration + +Local sudo is configured in two parts: + +- the local provider enables the device-local execution backend +- per-device policy decides who can request sudo on a given machine and which local account to use + +## Required Configuration + +Minimal device policy example: + +```yaml +devices: + device-alpha: + device_id: "11111111-2222-3333-4444-555555555555" + name: "Example Workstation" + enabled: true + local_elevation: + enabled: true + accounts: + - email: "user@example.com" + local_username: "localuser" +``` + +Production agents use a generated machine-derived `device_id`. You can print the current machine's value with: + +```bash +thand config device-id +``` + +For deterministic dev/CI setups, non-production builds may override the generated value with `THAND_DEV_DEVICE_ID_OVERRIDE`. That override path is intentionally not available in production binaries. + +## Provider Configuration + +If you are using the built-in local provider defaults, you usually do not need to restate the provider stanza at all. + +The embedded `local_sudo` role ships with the timed sudo workflow only. Command mode remains available to custom roles and workflows, but it is not part of the default local sudo role. + +Example: + +```yaml +providers: + local-elevation: + provider: local + enabled: true +``` + +## Account Mapping + +Per-device account mappings decide which local account receives sudo. + +You can match by: + +- `identity` +- `email` +- `username` + +Example: + +```yaml +devices: + device-alpha: + device_id: "11111111-2222-3333-4444-555555555555" + enabled: true + local_elevation: + enabled: true + accounts: + - identity: "identity-abc123" + local_username: "localuser" + - email: "user@example.com" + local_username: "localuser" +``` + +## Allowed Modes + +You can restrict a device to specific local-sudo modes. + +```yaml +devices: + device-alpha: + device_id: "11111111-2222-3333-4444-555555555555" + local_elevation: + enabled: true + allowed_modes: + - timed + - command +``` + +If `allowed_modes` is omitted, both timed and command mode are allowed at the device-policy layer, but the embedded `local_sudo` role still exposes only the timed workflow by default. + +## Guardrails + +You can add guardrails for unsafe local targets. + +```yaml +devices: + device-alpha: + device_id: "11111111-2222-3333-4444-555555555555" + local_elevation: + enabled: true + denied_usernames: + - root + - daemon + - nobody + allowed_uid_ranges: + - "1000-60000" +``` + +`denied_usernames` blocks sensitive local accounts even if they are mapped accidentally. + +`allowed_uid_ranges` constrains requests to human-style local accounts instead of system accounts. + +## Operational Notes + +- local sudo routes only to fresh live agent registration state +- live routes are published by agents to the shared Temporal device-route registry +- local-sudo execution planning reads device policy from the shared Temporal device-definition registry +- the device identity is the canonical `device_id` +- operators can print the local machine device ID with `thand config device-id` +- `thand request sudo` defaults to the current machine when `--device` is omitted +- static `execution_target` routing is no longer used +- local-sudo execution planning runs internally at the start of `authorize` +- authorize waits for the device for a bounded window + +## Related Docs + +- [Local Sudo Usage](/api/agent/local-sudo.html) +- [Workflow Tasks](/configuration/workflows/tasks.html) +- [Configuration](/configuration/) diff --git a/docs/getting-started.md b/docs/getting-started.md index bc4f31ee..08507977 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -165,6 +165,11 @@ Request temporary sudo access on your local machine: thand-agent request sudo --duration 30m --reason "System maintenance" ``` +For device-targeted usage and configuration details, see: + +- [Local Sudo Usage](api/agent/local-sudo.md) +- [Local Sudo Configuration](configuration/local-sudo.md) + ## Next Steps - **[Environment Setup](../environments/)** - Configure Thand for specific environments diff --git a/internal/api/elevate.go b/internal/api/elevate.go index de377abb..33050bc3 100644 --- a/internal/api/elevate.go +++ b/internal/api/elevate.go @@ -6,7 +6,9 @@ import ( "context" "errors" "fmt" + "strings" + swfCtx "github.com/serverlessworkflow/sdk-go/v3/impl/ctx" "github.com/sirupsen/logrus" "github.com/thand-io/agent/internal/models" sdkConstants "github.com/thand-io/agent/sdk/constants" @@ -56,6 +58,10 @@ func (s *Service) Elevate(ctx context.Context, input ElevationInput) (*models.Wo request := input.Request + if err := models.NormalizeLocalSudoRequest(&request, s.cfg.GetProviderDefinitions()); err != nil { + return nil, fmt.Errorf("failed to normalize elevation request: %w", err) + } + if input.User != nil { exportableSession := &models.ExportableSession{ Session: input.User, @@ -115,6 +121,13 @@ func (s *Service) Resume(ctx context.Context, input ResumeInput) (*models.Elevat workflowTask, err := s.workflows.ResumeWorkflow(workflow) if err != nil { + if isAlreadyCompletedResumeError(err) { + logrus.WithFields(logrus.Fields{ + "workflow_id": workflow.GetWorkflowID(), + }).Debug("elevation resume: workflow already completed") + workflow.Status = swfCtx.CompletedStatus + return workflow, nil + } return nil, fmt.Errorf("failed to resume workflow: %w", err) } @@ -128,3 +141,10 @@ func (s *Service) Resume(ctx context.Context, input ResumeInput) (*models.Elevat return workflowTask, nil } + +func isAlreadyCompletedResumeError(err error) bool { + if err == nil { + return false + } + return strings.Contains(strings.ToLower(err.Error()), "workflow execution already completed") +} diff --git a/internal/api/elevate_test.go b/internal/api/elevate_test.go index adf1996f..40206c3c 100644 --- a/internal/api/elevate_test.go +++ b/internal/api/elevate_test.go @@ -7,6 +7,7 @@ import ( "time" cloudevents "github.com/cloudevents/sdk-go/v2" + swfCtx "github.com/serverlessworkflow/sdk-go/v3/impl/ctx" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thand-io/agent/internal/config" @@ -299,3 +300,23 @@ func TestResume_RunnerError(t *testing.T) { require.Error(t, err) assert.ErrorIs(t, err, sentinel) } + +func TestResume_WorkflowAlreadyCompletedReturnsExistingWorkflow(t *testing.T) { + task := newElevateTask() + runner := &mockRunner{ + resumeFn: func(_ *models.ElevateWorkflowTask) (*models.ElevateWorkflowTask, error) { + return nil, errors.New("failed to signal workflow: workflow execution already completed") + }, + } + svc := newTestService(true, runner) + + result, err := svc.Resume(context.Background(), ResumeInput{ + Workflow: task, + User: &models.User{Email: "eve@example.com"}, + }) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Same(t, task, result) + assert.Equal(t, swfCtx.CompletedStatus, result.GetStatus()) +} diff --git a/internal/config/device_local_elevation_test.go b/internal/config/device_local_elevation_test.go new file mode 100644 index 00000000..fc05196c --- /dev/null +++ b/internal/config/device_local_elevation_test.go @@ -0,0 +1,57 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadConfigDecodesDeviceLocalElevationFields(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + configBody := []byte(` +version: "1.0" +environment: + name: test + platform: local +devices: + device-alpha: + device_id: "device-alpha" + name: "Example Workstation" + enabled: true + local_elevation: + enabled: true + allowed_modes: + - timed + - command + accounts: + - email: user@example.com + local_username: exampleuser +`) + if err := os.WriteFile(configPath, configBody, 0600); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + cfg, err := Load(configPath) + if err != nil { + t.Fatalf("Load returned error: %v", err) + } + + device, err := cfg.GetDevice("device-alpha") + if err != nil { + t.Fatalf("GetDevice returned error: %v", err) + } + + if got, want := device.ID, "device-alpha"; got != want { + t.Fatalf("device id = %q, want %q", got, want) + } + if device.LocalElevation == nil { + t.Fatal("expected local_elevation to be decoded") + } + if got, want := len(device.LocalElevation.Accounts), 1; got != want { + t.Fatalf("accounts len = %d, want %d", got, want) + } + if got, want := device.LocalElevation.Accounts[0].LocalUsername, "exampleuser"; got != want { + t.Fatalf("local username = %q, want %q", got, want) + } +} diff --git a/internal/config/environment/local/roles.yaml b/internal/config/environment/local/roles.yaml index a4395b88..9c6d62c1 100644 --- a/internal/config/environment/local/roles.yaml +++ b/internal/config/environment/local/roles.yaml @@ -19,6 +19,7 @@ roles: - path:/root # Only allow access to /root directory providers: - local + - local-elevation enabled: true # Power Users Group @@ -41,6 +42,7 @@ roles: - path:/opt providers: - local + - local-elevation enabled: true # Users Group @@ -62,6 +64,20 @@ roles: - path:/var/tmp providers: - local + - local-elevation + enabled: true + + local_sudo: + name: Local Sudo + description: Time-bound local sudo access + workflows: + - local_sudo_timed_elevation + permissions: + allow: + - local:sudo:* + providers: + - local + - local-elevation enabled: true # Operators Group @@ -84,5 +100,5 @@ roles: - path:/etc/systemd providers: - local + - local-elevation enabled: true - diff --git a/internal/config/execution_plan_activity_test.go b/internal/config/execution_plan_activity_test.go new file mode 100644 index 00000000..ac7662ea --- /dev/null +++ b/internal/config/execution_plan_activity_test.go @@ -0,0 +1,137 @@ +package config + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thand-io/agent/internal/models" +) + +type executionPlanActivityTestProvider struct { + *models.BaseProvider +} + +func (p *executionPlanActivityTestProvider) ValidateRole( + ctx context.Context, + user *models.Identity, + role *models.Role, +) (map[string]any, error) { + return map[string]any{}, nil +} + +func newExecutionPlanActivityTestProvider(identifier string) *executionPlanActivityTestProvider { + caps := models.NewProviderCapabilities().WithDefaultProvisioningConfiguration() + providerCfg := models.ProviderConfig{ + Name: identifier, + Provider: identifier, + Enabled: true, + Capabilities: caps, + Config: &models.BasicConfig{}, + } + + provider := &executionPlanActivityTestProvider{ + BaseProvider: models.NewBaseProvider(identifier, providerCfg, caps), + } + provider.SetReady() + return provider +} + +func newExecutionPlanActivityRole(identifier, name string) *models.Role { + return &models.Role{ + Identifier: identifier, + Name: name, + Enabled: true, + Permissions: models.RolePermissions{ + Allow: models.RoleStatements{{Operations: []string{"local:test"}}}, + }, + } +} + +func TestBuildExecutionPlanActivityUsesSharedDeviceDefinitionsForLocalSudo(t *testing.T) { + t.Parallel() + + cfg := DefaultConfig() + cfg.AddProvider("local-elevation", newExecutionPlanActivityTestProvider("local")) + + activities := &thandActivities{ + config: cfg, + lookupDeviceDefinition: func(ctx context.Context, deviceID string) (*models.Device, error) { + return &models.Device{ + ID: deviceID, + Name: "Device Alpha", + Enabled: true, + LocalElevation: &models.DeviceLocalElevationPolicy{ + Enabled: true, + AllowedModes: []string{string(models.LocalSudoModeTimed)}, + Accounts: []models.DeviceLocalElevationAccount{ + {Email: "user@example.com", LocalUsername: "workstation-user"}, + }, + DeniedUsernames: []string{"root"}, + AllowedUIDRanges: []string{"1000-60000"}, + }, + }, nil + }, + } + + plan, err := activities.BuildExecutionPlan(context.Background(), models.ExecutionPlanRequest{ + WorkflowID: "wf-local-sudo", + ElevateRequest: &models.ElevateRequestInternal{ + ElevateRequest: models.ElevateRequest{ + Role: newExecutionPlanActivityRole(models.LocalSudoRoleIdentifier, "Local Sudo"), + Providers: []string{"local-elevation"}, + Workflow: models.LocalSudoTimedWorkflowName, + Device: "device-alpha", + Reason: "maintenance", + Duration: "30m", + Identities: []string{"user@example.com"}, + Metadata: models.LocalSudoRequestMetadata{ + Mode: models.LocalSudoModeTimed, + }.AsMap(), + }, + }, + }) + require.NoError(t, err) + require.Len(t, plan.Entries, 1) + + meta, err := models.DecodeLocalSudoRequestMetadata(plan.Entries[0].AuthorizeRequest.Metadata) + require.NoError(t, err) + assert.Equal(t, "device-alpha", plan.Entries[0].DeviceID) + assert.Equal(t, "workstation-user", meta.LocalUsername) + assert.Equal(t, []string{"root"}, meta.DeniedUsernames) +} + +func TestBuildExecutionPlanActivityFailsWhenSharedDeviceDefinitionIsMissing(t *testing.T) { + t.Parallel() + + cfg := DefaultConfig() + cfg.AddProvider("local-elevation", newExecutionPlanActivityTestProvider("local")) + + activities := &thandActivities{ + config: cfg, + lookupDeviceDefinition: func(ctx context.Context, deviceID string) (*models.Device, error) { + return nil, assert.AnError + }, + } + + _, err := activities.BuildExecutionPlan(context.Background(), models.ExecutionPlanRequest{ + WorkflowID: "wf-local-sudo", + ElevateRequest: &models.ElevateRequestInternal{ + ElevateRequest: models.ElevateRequest{ + Role: newExecutionPlanActivityRole(models.LocalSudoRoleIdentifier, "Local Sudo"), + Providers: []string{"local-elevation"}, + Workflow: models.LocalSudoTimedWorkflowName, + Device: "device-alpha", + Reason: "maintenance", + Duration: "30m", + Identities: []string{"user@example.com"}, + Metadata: models.LocalSudoRequestMetadata{ + Mode: models.LocalSudoModeTimed, + }.AsMap(), + }, + }, + }) + require.Error(t, err) + assert.ErrorContains(t, err, "assert.AnError general error for testing") +} diff --git a/internal/config/local_sudo_execution_plan.go b/internal/config/local_sudo_execution_plan.go index 8bcfb41c..d0dd9273 100644 --- a/internal/config/local_sudo_execution_plan.go +++ b/internal/config/local_sudo_execution_plan.go @@ -1,29 +1,107 @@ package config -import "github.com/thand-io/agent/internal/models" +import ( + "fmt" + "strings" + + "github.com/thand-io/agent/internal/models" +) -// localSudoExecutionPlanDecorator is a no-op until local sudo request shaping -// lands. Keeping the hook in the execution-plan layer lets later commits add -// the feature without reshaping this baseline. type localSudoExecutionPlanDecorator struct{} -func (localSudoExecutionPlanDecorator) Applies(*models.ElevateRequestInternal) bool { - return false +func (localSudoExecutionPlanDecorator) Applies(elevateRequest *models.ElevateRequestInternal) bool { + return elevateRequest != nil && models.IsLocalSudoRequest(&elevateRequest.ElevateRequest) } func (localSudoExecutionPlanDecorator) Decorate( - models.ConfigImpl, - *models.WorkflowRoleRequest, - *models.ElevateRequestInternal, - executionPlanBuildOptions, + cfg models.ConfigImpl, + req *models.WorkflowRoleRequest, + elevateRequest *models.ElevateRequestInternal, + opts executionPlanBuildOptions, ) error { + meta, err := buildLocalSudoRequestMetadata(cfg, elevateRequest, req.Identity, req.ResolvedIdentity, opts.LookupDeviceDefinition) + if err != nil { + return err + } + + req.DeviceID = meta.DeviceID + req.Metadata = meta.AsMap() return nil } func (localSudoExecutionPlanDecorator) Finalize( - *models.WorkflowRoleRequest, - *models.ElevateRequestInternal, - string, + req *models.WorkflowRoleRequest, + elevateRequest *models.ElevateRequestInternal, + entryID string, ) error { + meta, err := models.DecodeLocalSudoRequestMetadata(req.Metadata) + if err != nil { + return err + } + meta.GrantID = entryID + req.Metadata = meta.AsMap() return nil } + +func buildLocalSudoRequestMetadata( + cfg models.ConfigImpl, + elevateRequest *models.ElevateRequestInternal, + identityID string, + resolvedIdentity *models.Identity, + lookupDeviceDefinition func(deviceID string) (*models.Device, error), +) (models.LocalSudoRequestMetadata, error) { + meta, err := models.DecodeLocalSudoRequestMetadata(elevateRequest.Metadata) + if err != nil { + return meta, err + } + + deviceID := strings.TrimSpace(elevateRequest.Device) + if deviceID == "" { + deviceID = strings.TrimSpace(meta.DeviceID) + } + if deviceID == "" { + return meta, fmt.Errorf("local sudo request is missing a device_id") + } + + if lookupDeviceDefinition == nil { + lookupDeviceDefinition = cfg.GetDevice + } + + device, err := lookupDeviceDefinition(deviceID) + if err != nil { + return meta, err + } + if !device.Enabled { + return meta, fmt.Errorf("device %q is disabled", deviceID) + } + if device.LocalElevation == nil { + return meta, fmt.Errorf("device %q does not have local elevation configured", deviceID) + } + if !device.LocalElevation.AllowsMode(string(meta.Mode)) { + return meta, fmt.Errorf("device %q does not allow local sudo mode %q", deviceID, meta.Mode) + } + + identity := resolvedIdentity + if identity == nil { + identity, err = cfg.GetIdentity(identityID) + } + if err != nil || identity == nil { + identity = &models.Identity{ + ID: identityID, + User: &models.User{ + Email: identityID, + }, + } + } + + localUsername, err := device.LocalElevation.ResolveLocalUsername(identityID, identity) + if err != nil { + return meta, err + } + + meta.DeviceID = device.ID + meta.LocalUsername = localUsername + meta.DeniedUsernames = append([]string(nil), device.LocalElevation.DeniedUsernames...) + meta.AllowedUIDRanges = append([]string(nil), device.LocalElevation.AllowedUIDRanges...) + return meta, nil +} diff --git a/internal/config/providers_local.go b/internal/config/providers_local.go new file mode 100644 index 00000000..5772b653 --- /dev/null +++ b/internal/config/providers_local.go @@ -0,0 +1,5 @@ +package config + +import ( + _ "github.com/thand-io/agent/internal/providers/local" +) diff --git a/internal/daemon/elevate.go b/internal/daemon/elevate.go index 3714a7b9..36fdb992 100644 --- a/internal/daemon/elevate.go +++ b/internal/daemon/elevate.go @@ -20,7 +20,7 @@ import ( "github.com/thand-io/agent/internal/workflows/manager" ) -// getElevate handles GET /api/v1/elevate?role=admin&target=server&reason=maintenance +// getElevate handles GET /api/v1/elevate?role=admin&device=&reason=maintenance // // @Summary Request role elevation // @Description Request elevation to a specific role with static parameters @@ -29,6 +29,7 @@ import ( // @Produce json // @Param role query string true "Role name" // @Param provider query string true "Provider name" +// @Param device query string false "Canonical device_id for device-local workflows" // @Param reason query string true "Reason for elevation" // @Param duration query string false "Duration of elevation" // @Param workflow query string false "Workflow name" @@ -72,6 +73,7 @@ func (s *Server) getElevate(c *gin.Context) { Providers: []string{request.Provider}, Identities: request.Identities, Workflow: primaryWorkflow, + Device: request.Device, Reason: request.Reason, Duration: request.Duration, Session: request.Session, @@ -671,6 +673,7 @@ type ElevateStaticPageData struct { Identities []models.Identity `json:"identities"` Providers []string `json:"providers"` Roles []string `json:"roles"` + Device string `json:"device"` Duration string `json:"duration"` Reason string `json:"reason"` Tenants []string `json:"tenants"` @@ -682,6 +685,9 @@ func (s *Server) getElevationPagePrefill(c *gin.Context) ElevateStaticPageData { } preFilledTenants := c.QueryArray("tenants") + if len(preFilledTenants) == 0 { + preFilledTenants = c.QueryArray("tenant") + } validTenants := []string{} for _, tenantID := range preFilledTenants { tenant, err := s.Config.GetTenant(tenantID) @@ -729,6 +735,8 @@ func (s *Server) getElevationPagePrefill(c *gin.Context) ElevateStaticPageData { data.Roles = roles } + data.Device = strings.TrimSpace(c.Query("device")) + // Get duration from query parameters duration := c.Query("duration") if len(duration) > 0 { diff --git a/internal/daemon/elevate_prefill_test.go b/internal/daemon/elevate_prefill_test.go new file mode 100644 index 00000000..bcfa6988 --- /dev/null +++ b/internal/daemon/elevate_prefill_test.go @@ -0,0 +1,31 @@ +package daemon + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/thand-io/agent/internal/config" +) + +func TestGetElevationPagePrefillIncludesDevice(t *testing.T) { + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest( + "GET", + "/elevate/static?provider=local-elevation&role=local_sudo&device=device-alpha&duration=1h&reason=test", + nil, + ) + + server := NewServer(config.DefaultConfig()) + data := server.getElevationPagePrefill(ctx) + + assert.Equal(t, []string{"local-elevation"}, data.Providers) + assert.Equal(t, []string{"local_sudo"}, data.Roles) + assert.Equal(t, "device-alpha", data.Device) + assert.Equal(t, "1h", data.Duration) + assert.Equal(t, "test", data.Reason) +} diff --git a/internal/daemon/static/elevate_static.html b/internal/daemon/static/elevate_static.html index 956a58fa..0631bdf7 100644 --- a/internal/daemon/static/elevate_static.html +++ b/internal/daemon/static/elevate_static.html @@ -144,6 +144,12 @@

Select Identities to Assign Role

+ +