diff --git a/internal/core/application/service.go b/internal/core/application/service.go index 7959fa386..a670c6528 100644 --- a/internal/core/application/service.go +++ b/internal/core/application/service.go @@ -426,21 +426,19 @@ func (s *service) Stop() { s.sweeper.stop() commitmentTxIds, err := s.repoManager.Rounds().GetSweepableRounds(ctx) - if err == nil { - tapkeys := make([]string, 0) - - for _, commitmentTxId := range commitmentTxIds { - keys, err := s.repoManager.Vtxos(). - GetVtxoPubKeysByCommitmentTxid(ctx, commitmentTxId, 0) - if err != nil { - log.WithError(err).Warn("failed to get vtxo tap keys") - continue - } - - tapkeys = append(tapkeys, keys...) + if err == nil && len(commitmentTxIds) > 0 { + tapkeys, err := s.repoManager.Vtxos(). + GetVtxoPubKeysByCommitmentTxids(ctx, commitmentTxIds, 0) + if err != nil { + log.WithError(err).Warnf( + "failed to get vtxo tap keys for %d sweepable rounds; "+ + "skipping UnwatchScripts on shutdown, wallet may keep "+ + "watching these scripts until the next restart", + len(commitmentTxIds), + ) + } else { + s.stopWatchingVtxos(tapkeys) } - - s.stopWatchingVtxos(tapkeys) } // nolint @@ -3624,6 +3622,14 @@ func (s *service) startWatchingVtxos(vtxos []domain.Vtxo) error { return s.scanner.WatchScripts(context.Background(), scripts) } +// restoreWatchingVtxos re-registers every sweepable round's vtxo pubkeys +// with the chain scanner so we resume receiving notifications after a +// restart. The pubkey lookup uses the bulk repo method +// GetVtxoPubKeysByCommitmentTxids, so we make two repository calls (one +// for the round list, one for all keys) instead of one per round; sqlite +// batches the bulk call per 5000 txids. The cross-process WatchScripts +// gRPC call is chunked by walletclient.WatchScripts to stay below the +// default 4 MiB gRPC max-message size at large script counts.
func (s *service) restoreWatchingVtxos() error { ctx := context.Background() @@ -3632,30 +3638,30 @@ func (s *service) restoreWatchingVtxos() error { return err } - total := len(commitmentTxIds) - lastMilestone := 0 - scripts := make([]string, 0) - for i, commitmentTxId := range commitmentTxIds { - tapKeys, err := s.repoManager.Vtxos().GetVtxoPubKeysByCommitmentTxid(ctx, commitmentTxId, 0) - if err != nil { - return err - } + if len(commitmentTxIds) == 0 { + return nil + } - for _, key := range tapKeys { - // skip if the key is not a valid x-only hex encoded pubkey - if len(key) != 64 { - continue - } - scripts = append(scripts, fmt.Sprintf("5120%s", key)) - } + tapKeys, err := s.repoManager.Vtxos(). + GetVtxoPubKeysByCommitmentTxids(ctx, commitmentTxIds, 0) + if err != nil { + return err + } - if milestone := (i + 1) * 100 / total / 10; milestone > lastMilestone { - lastMilestone = milestone - log.Debugf("restore watching vtxos: %d%%...", milestone*10) + scripts := make([]string, 0, len(tapKeys)) + for _, key := range tapKeys { + // Skip values that are not a 32-byte x-only pubkey encoded as 64 + // hex chars. arkd writes valid keys, but defending against a + // corrupted DB row here means a single bad pubkey cannot poison + // the entire WatchScripts gRPC payload at startup recovery. 
+ decoded, err := hex.DecodeString(key) + if err != nil || len(decoded) != 32 { + continue } + scripts = append(scripts, fmt.Sprintf("5120%s", key)) } - if len(scripts) <= 0 { + if len(scripts) == 0 { return nil } @@ -3663,7 +3669,10 @@ func (s *service) restoreWatchingVtxos() error { return err } - log.Debugf("restored watching %d vtxo scripts", len(scripts)) + log.Debugf( + "restored watching %d vtxo scripts from %d sweepable rounds", + len(scripts), len(commitmentTxIds), + ) return nil } diff --git a/internal/core/domain/vtxo_repo.go b/internal/core/domain/vtxo_repo.go index 793ff2d74..26e186468 100644 --- a/internal/core/domain/vtxo_repo.go +++ b/internal/core/domain/vtxo_repo.go @@ -30,6 +30,11 @@ type VtxoRepository interface { ) ( []string, error, ) + GetVtxoPubKeysByCommitmentTxids( + ctx context.Context, commitmentTxids []string, withMinimumAmount uint64, + ) ( + []string, error, + ) GetPendingSpentVtxosWithPubKeys( ctx context.Context, pubkeys []string, diff --git a/internal/infrastructure/db/badger/vtxo_repo.go b/internal/infrastructure/db/badger/vtxo_repo.go index 6efd61089..2dcadb531 100644 --- a/internal/infrastructure/db/badger/vtxo_repo.go +++ b/internal/infrastructure/db/badger/vtxo_repo.go @@ -318,15 +318,17 @@ func (r *vtxoRepository) GetVtxoPubKeysByCommitmentTxid( return nil, err } - // Combine and deduplicate by pubkey + // Combine and deduplicate by pubkey. The amount comparison must be >= to + // match the WHERE v.amount >= $1 contract used by the sqlite and postgres + // backends; a VTXO with Amount equal to the filter is included. 
pubkeyMap := make(map[string]bool) for _, vtxo := range vtxos1 { - if vtxo.Amount > amountFilter { + if vtxo.Amount >= amountFilter { pubkeyMap[vtxo.PubKey] = true } } for _, vtxo := range vtxos2 { - if vtxo.Amount > amountFilter { + if vtxo.Amount >= amountFilter { pubkeyMap[vtxo.PubKey] = true } } @@ -339,6 +341,83 @@ func (r *vtxoRepository) GetVtxoPubKeysByCommitmentTxid( return taprootKeys, nil } +// GetVtxoPubKeysByCommitmentTxids is the bulk variant of +// GetVtxoPubKeysByCommitmentTxid. It returns the deduplicated set of vtxo +// pubkeys whose root commitment_txid is in the given list, or whose +// CommitmentTxids slice intersects the given list. badgerhold has no native +// "slice intersects set" operator, so the second scan uses a MatchFunc that +// walks the in-memory slice; the SQL backends accomplish the same with a +// JOIN against vtxo_commitment_txid in a single query. +func (r *vtxoRepository) GetVtxoPubKeysByCommitmentTxids( + ctx context.Context, commitmentTxids []string, amountFilter uint64, +) ([]string, error) { + if len(commitmentTxids) == 0 { + return nil, nil + } + + idxIfaces := make([]interface{}, len(commitmentTxids)) + for i, txid := range commitmentTxids { + idxIfaces[i] = txid + } + + // Two scans of the vtxo store: one for vtxos whose RootCommitmentTxid is in + // the set, one for vtxos whose CommitmentTxids slice intersects the set. + // badgerhold has no Contains-In, so we fall back to a single scan with a + // matcher function for the CommitmentTxids case. + query1 := badgerhold.Where("RootCommitmentTxid"). + In(idxIfaces...). + And("Amount"). + Ge(amountFilter) + vtxos1, err := r.findVtxos(ctx, query1) + if err != nil { + return nil, err + } + + wanted := make(map[string]struct{}, len(commitmentTxids)) + for _, t := range commitmentTxids { + wanted[t] = struct{}{} + } + query2 := badgerhold.Where("CommitmentTxids"). 
+ MatchFunc(func(ra *badgerhold.RecordAccess) (bool, error) { + txids, ok := ra.Field().([]string) + if !ok { + return false, nil + } + for _, t := range txids { + if _, hit := wanted[t]; hit { + return true, nil + } + } + return false, nil + }). + And("Amount"). + Ge(amountFilter) + vtxos2, err := r.findVtxos(ctx, query2) + if err != nil { + return nil, err + } + + // Amount comparison is >= to match the sqlite/postgres + // WHERE v.amount >= $1 contract; including amount == amountFilter. + pubkeyMap := make(map[string]struct{}) + for _, vtxo := range vtxos1 { + if vtxo.Amount >= amountFilter { + pubkeyMap[vtxo.PubKey] = struct{}{} + } + } + for _, vtxo := range vtxos2 { + if vtxo.Amount >= amountFilter { + pubkeyMap[vtxo.PubKey] = struct{}{} + } + } + + taprootKeys := make([]string, 0, len(pubkeyMap)) + for pubkey := range pubkeyMap { + taprootKeys = append(taprootKeys, pubkey) + } + return taprootKeys, nil +} + func (r *vtxoRepository) GetPendingSpentVtxosWithPubKeys( ctx context.Context, pubkeys []string, after, before int64, ) ([]domain.Vtxo, error) { diff --git a/internal/infrastructure/db/postgres/migration/20260527150000_vtxo_commitment_txid_index.down.sql b/internal/infrastructure/db/postgres/migration/20260527150000_vtxo_commitment_txid_index.down.sql new file mode 100644 index 000000000..50b15824f --- /dev/null +++ b/internal/infrastructure/db/postgres/migration/20260527150000_vtxo_commitment_txid_index.down.sql @@ -0,0 +1 @@ +DROP INDEX IF EXISTS idx_vtxo_commitment_txid_commitment_txid; diff --git a/internal/infrastructure/db/postgres/migration/20260527150000_vtxo_commitment_txid_index.up.sql b/internal/infrastructure/db/postgres/migration/20260527150000_vtxo_commitment_txid_index.up.sql new file mode 100644 index 000000000..b54965bb6 --- /dev/null +++ b/internal/infrastructure/db/postgres/migration/20260527150000_vtxo_commitment_txid_index.up.sql @@ -0,0 +1,2 @@ +CREATE INDEX IF NOT EXISTS idx_vtxo_commitment_txid_commitment_txid + ON 
vtxo_commitment_txid (commitment_txid); diff --git a/internal/infrastructure/db/postgres/sqlc/queries/query.sql.go b/internal/infrastructure/db/postgres/sqlc/queries/query.sql.go index a59cab93d..21a52f723 100644 --- a/internal/infrastructure/db/postgres/sqlc/queries/query.sql.go +++ b/internal/infrastructure/db/postgres/sqlc/queries/query.sql.go @@ -1633,7 +1633,7 @@ func (q *Queries) SelectVtxo(ctx context.Context, arg SelectVtxoParams) ([]Selec } const selectVtxoPubKeysByCommitmentTxid = `-- name: SelectVtxoPubKeysByCommitmentTxid :many -SELECT DISTINCT v.pubkey +SELECT DISTINCT v.pubkey FROM vtxo_vw v WHERE v.amount >= $1 AND (v.commitment_txid = $2 @@ -1668,6 +1668,53 @@ func (q *Queries) SelectVtxoPubKeysByCommitmentTxid(ctx context.Context, arg Sel return items, nil } +const selectVtxoPubKeysByCommitmentTxids = `-- name: SelectVtxoPubKeysByCommitmentTxids :many +SELECT DISTINCT v.pubkey +FROM vtxo v +WHERE v.amount >= $1 + AND ( + v.commitment_txid = ANY($2::text[]) + OR EXISTS ( + SELECT 1 FROM vtxo_commitment_txid vc + WHERE vc.vtxo_txid = v.txid AND vc.vtxo_vout = v.vout + AND vc.commitment_txid = ANY($2::text[]) + ) + ) +` + +type SelectVtxoPubKeysByCommitmentTxidsParams struct { + MinAmount int64 + CommitmentTxids []string +} + +// Bulk variant of SelectVtxoPubKeysByCommitmentTxid: returns the +// deduplicated set of vtxo pubkeys for any of the given commitment_txids. +// Used at startup by restoreWatchingVtxos to collapse what was an N+1 +// per-round loop into a single SQL call. The named parameter is reused +// in both IN/ANY clauses; postgres binds it once. 
+func (q *Queries) SelectVtxoPubKeysByCommitmentTxids(ctx context.Context, arg SelectVtxoPubKeysByCommitmentTxidsParams) ([]string, error) { + rows, err := q.db.QueryContext(ctx, selectVtxoPubKeysByCommitmentTxids, arg.MinAmount, pq.Array(arg.CommitmentTxids)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var pubkey string + if err := rows.Scan(&pubkey); err != nil { + return nil, err + } + items = append(items, pubkey) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const selectVtxosOutpointsByArkTxidRecursive = `-- name: SelectVtxosOutpointsByArkTxidRecursive :many WITH RECURSIVE descendants_chain AS ( -- seed diff --git a/internal/infrastructure/db/postgres/sqlc/query.sql b/internal/infrastructure/db/postgres/sqlc/query.sql index f24c4f701..d7c9b2947 100644 --- a/internal/infrastructure/db/postgres/sqlc/query.sql +++ b/internal/infrastructure/db/postgres/sqlc/query.sql @@ -282,12 +282,30 @@ SELECT sqlc.embed(offchain_tx_vw) FROM offchain_tx_vw WHERE txid = @txid AND COA SELECT * FROM scheduled_session ORDER BY updated_at DESC LIMIT 1; -- name: SelectVtxoPubKeysByCommitmentTxid :many -SELECT DISTINCT v.pubkey +SELECT DISTINCT v.pubkey FROM vtxo_vw v WHERE v.amount >= @min_amount AND (v.commitment_txid = @commitment_txid OR (',' || COALESCE(v.commitments::text, '') || ',') LIKE '%,' || @commitment_txid || ',%'); +-- Bulk variant of SelectVtxoPubKeysByCommitmentTxid: returns the +-- deduplicated set of vtxo pubkeys for any of the given commitment_txids. +-- Used at startup by restoreWatchingVtxos to collapse what was an N+1 +-- per-round loop into a single SQL call. The named parameter is reused +-- in both IN/ANY clauses; postgres binds it once. 
+-- name: SelectVtxoPubKeysByCommitmentTxids :many +SELECT DISTINCT v.pubkey +FROM vtxo v +WHERE v.amount >= @min_amount + AND ( + v.commitment_txid = ANY(@commitment_txids::text[]) + OR EXISTS ( + SELECT 1 FROM vtxo_commitment_txid vc + WHERE vc.vtxo_txid = v.txid AND vc.vtxo_vout = v.vout + AND vc.commitment_txid = ANY(@commitment_txids::text[]) + ) + ); + -- name: SelectSweepableVtxoOutpointsByCommitmentTxid :many SELECT DISTINCT v.txid AS vtxo_txid, v.vout AS vtxo_vout FROM vtxo_vw v diff --git a/internal/infrastructure/db/postgres/vtxo_repo.go b/internal/infrastructure/db/postgres/vtxo_repo.go index 158872e7d..0fb872ad9 100644 --- a/internal/infrastructure/db/postgres/vtxo_repo.go +++ b/internal/infrastructure/db/postgres/vtxo_repo.go @@ -459,6 +459,31 @@ func (v *vtxoRepository) GetVtxoPubKeysByCommitmentTxid( return taprootKeys, nil } +// GetVtxoPubKeysByCommitmentTxids is the bulk variant of +// GetVtxoPubKeysByCommitmentTxid. It returns the deduplicated set of vtxo +// pubkeys whose root commitment_txid is in the given list, or whose +// vtxo_commitment_txid join row references one of those commitment txids. +// This replaces the per-round loops in restoreWatchingVtxos and Stop +// that previously fired one query per sweepable round (the N+1 pattern).
+func (v *vtxoRepository) GetVtxoPubKeysByCommitmentTxids( + ctx context.Context, commitmentTxids []string, withMinimumAmount uint64, +) ([]string, error) { + if len(commitmentTxids) == 0 { + return nil, nil + } + + taprootKeys, err := v.querier.SelectVtxoPubKeysByCommitmentTxids(ctx, + queries.SelectVtxoPubKeysByCommitmentTxidsParams{ + MinAmount: int64(withMinimumAmount), + CommitmentTxids: commitmentTxids, + }) + if err != nil { + return nil, err + } + + return taprootKeys, nil +} + func (v *vtxoRepository) GetPendingSpentVtxosWithPubKeys( ctx context.Context, pubkeys []string, after, before int64, ) ([]domain.Vtxo, error) { diff --git a/internal/infrastructure/db/service_test.go b/internal/infrastructure/db/service_test.go index ee006e436..1ae51f600 100644 --- a/internal/infrastructure/db/service_test.go +++ b/internal/infrastructure/db/service_test.go @@ -1009,6 +1009,53 @@ func testVtxoRepository(t *testing.T, svc ports.RepoManager) { require.NoError(t, err) require.Empty(t, tapKeys) + // Bulk variant: must return the deduplicated union of the per-txid + // results across all provided commitment_txids. + bulkKeys, err := svc.Vtxos().GetVtxoPubKeysByCommitmentTxids( + ctx, []string{otherCommitmentTxid}, 0, + ) + require.NoError(t, err) + require.Len(t, bulkKeys, 3) + require.ElementsMatch(t, []string{"tapkey1", "tapkey2", "tapkey3"}, bulkKeys) + + bulkKeys, err = svc.Vtxos().GetVtxoPubKeysByCommitmentTxids( + ctx, []string{otherCommitmentTxid}, 3000, + ) + require.NoError(t, err) + require.ElementsMatch(t, []string{"tapkey1", "tapkey3"}, bulkKeys) + + // Combine with a known existing commitmentTxid that has keys too, + // expect the dedup'd union, no duplicates. 
+ bulkKeys, err = svc.Vtxos().GetVtxoPubKeysByCommitmentTxids( + ctx, []string{otherCommitmentTxid, commitmentTxid}, 0, + ) + require.NoError(t, err) + seen := make(map[string]int) + for _, k := range bulkKeys { + seen[k]++ + } + for k, n := range seen { + require.Equalf(t, 1, n, "duplicate pubkey %s in bulk result", k) + } + // Verify the full union: keys from both commitment txids must be + // present (tapkey1/2/3 from otherCommitmentTxid, plus pubkey and + // pubkey2 from the earlier commitmentTxid seed). + require.Contains(t, bulkKeys, "tapkey1") + require.Contains(t, bulkKeys, "tapkey2") + require.Contains(t, bulkKeys, "tapkey3") + require.Contains(t, bulkKeys, pubkey) + require.Contains(t, bulkKeys, pubkey2) + + bulkKeys, err = svc.Vtxos().GetVtxoPubKeysByCommitmentTxids(ctx, nil, 0) + require.NoError(t, err) + require.Empty(t, bulkKeys) + + bulkKeys, err = svc.Vtxos().GetVtxoPubKeysByCommitmentTxids( + ctx, []string{nonExistentCommitmentTxid}, 0, + ) + require.NoError(t, err) + require.Empty(t, bulkKeys) + t.Run("test_get_pending_spent_vtxos", func(t *testing.T) { ctx := t.Context() diff --git a/internal/infrastructure/db/sqlite/export_test.go b/internal/infrastructure/db/sqlite/export_test.go new file mode 100644 index 000000000..1a2ee728b --- /dev/null +++ b/internal/infrastructure/db/sqlite/export_test.go @@ -0,0 +1,22 @@ +package sqlitedb + +import ( + "context" + + "github.com/arkade-os/arkd/internal/core/domain" +) + +// GetVtxoPubKeysByCommitmentTxidsBatched exposes the unexported batching +// helper to tests in sibling packages. The _test.go suffix keeps this out +// of the production binary. 
+func GetVtxoPubKeysByCommitmentTxidsBatched( + ctx context.Context, + repo domain.VtxoRepository, + commitmentTxids []string, + withMinimumAmount uint64, + batchSize int, +) ([]string, error) { + return repo.(*vtxoRepository).getVtxoPubKeysByCommitmentTxidsBatched( + ctx, commitmentTxids, withMinimumAmount, batchSize, + ) +} diff --git a/internal/infrastructure/db/sqlite/migration/20260527150000_vtxo_commitment_txid_index.down.sql b/internal/infrastructure/db/sqlite/migration/20260527150000_vtxo_commitment_txid_index.down.sql new file mode 100644 index 000000000..50b15824f --- /dev/null +++ b/internal/infrastructure/db/sqlite/migration/20260527150000_vtxo_commitment_txid_index.down.sql @@ -0,0 +1 @@ +DROP INDEX IF EXISTS idx_vtxo_commitment_txid_commitment_txid; diff --git a/internal/infrastructure/db/sqlite/migration/20260527150000_vtxo_commitment_txid_index.up.sql b/internal/infrastructure/db/sqlite/migration/20260527150000_vtxo_commitment_txid_index.up.sql new file mode 100644 index 000000000..b54965bb6 --- /dev/null +++ b/internal/infrastructure/db/sqlite/migration/20260527150000_vtxo_commitment_txid_index.up.sql @@ -0,0 +1,2 @@ +CREATE INDEX IF NOT EXISTS idx_vtxo_commitment_txid_commitment_txid + ON vtxo_commitment_txid (commitment_txid); diff --git a/internal/infrastructure/db/sqlite/sqlc/queries/query.sql.go b/internal/infrastructure/db/sqlite/sqlc/queries/query.sql.go index 0bb137d25..5372c5841 100644 --- a/internal/infrastructure/db/sqlite/sqlc/queries/query.sql.go +++ b/internal/infrastructure/db/sqlite/sqlc/queries/query.sql.go @@ -1772,6 +1772,78 @@ func (q *Queries) SelectVtxoPubKeysByCommitmentTxid(ctx context.Context, arg Sel return items, nil } +const selectVtxoPubKeysByCommitmentTxids = `-- name: SelectVtxoPubKeysByCommitmentTxids :many +SELECT DISTINCT v.pubkey +FROM vtxo v +WHERE v.amount >= ?1 + AND ( + v.commitment_txid IN (/*SLICE:commitment_txids*/?) 
+ OR EXISTS ( + SELECT 1 FROM vtxo_commitment_txid vc + WHERE vc.vtxo_txid = v.txid AND vc.vtxo_vout = v.vout + AND vc.commitment_txid IN (/*SLICE:commitment_txids_alt*/?) + ) + ) +` + +type SelectVtxoPubKeysByCommitmentTxidsParams struct { + MinAmount int64 + CommitmentTxids []string + CommitmentTxidsAlt []string +} + +// Bulk variant of SelectVtxoPubKeysByCommitmentTxid: returns the +// deduplicated set of vtxo pubkeys for any of the given commitment_txids. +// Used at startup by restoreWatchingVtxos to collapse what was an N+1 +// per-round loop into a single SQL call. +// +// Two slice placeholders bind the same list of txids: sqlc's sqlite +// generator only rewrites the first occurrence of sqlc.slice('name') per +// query, so checking both v.commitment_txid and the join table forces +// the second IN clause to use a distinct slice name. The Go caller +// passes the same []string to both. +func (q *Queries) SelectVtxoPubKeysByCommitmentTxids(ctx context.Context, arg SelectVtxoPubKeysByCommitmentTxidsParams) ([]string, error) { + query := selectVtxoPubKeysByCommitmentTxids + var queryParams []interface{} + queryParams = append(queryParams, arg.MinAmount) + if len(arg.CommitmentTxids) > 0 { + for _, v := range arg.CommitmentTxids { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:commitment_txids*/?", strings.Repeat(",?", len(arg.CommitmentTxids))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:commitment_txids*/?", "NULL", 1) + } + if len(arg.CommitmentTxidsAlt) > 0 { + for _, v := range arg.CommitmentTxidsAlt { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:commitment_txids_alt*/?", strings.Repeat(",?", len(arg.CommitmentTxidsAlt))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:commitment_txids_alt*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var pubkey string + if err := rows.Scan(&pubkey); err != nil { + return nil, err + } + items = append(items, pubkey) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const selectVtxosOutpointsByArkTxidRecursive = `-- name: SelectVtxosOutpointsByArkTxidRecursive :many WITH RECURSIVE descendants_chain AS ( -- seed diff --git a/internal/infrastructure/db/sqlite/sqlc/query.sql b/internal/infrastructure/db/sqlite/sqlc/query.sql index f4e2e9fe1..11b96d238 100644 --- a/internal/infrastructure/db/sqlite/sqlc/query.sql +++ b/internal/infrastructure/db/sqlite/sqlc/query.sql @@ -294,6 +294,29 @@ WHERE v.amount >= sqlc.arg('min_amount') AND (v.commitment_txid = sqlc.arg('commitment_txid') OR (',' || COALESCE(v.commitments, '') || ',') LIKE '%,' || sqlc.arg('commitment_txid') || ',%'); +-- Bulk variant of SelectVtxoPubKeysByCommitmentTxid: returns the +-- deduplicated set of vtxo pubkeys for any of the given commitment_txids. +-- Used at startup by restoreWatchingVtxos to collapse what was an N+1 +-- per-round loop into a single SQL call. +-- +-- Two slice placeholders bind the same list of txids: sqlc's sqlite +-- generator only rewrites the first occurrence of sqlc.slice('name') per +-- query, so checking both v.commitment_txid and the join table forces +-- the second IN clause to use a distinct slice name. The Go caller +-- passes the same []string to both. 
+-- name: SelectVtxoPubKeysByCommitmentTxids :many +SELECT DISTINCT v.pubkey +FROM vtxo v +WHERE v.amount >= sqlc.arg('min_amount') + AND ( + v.commitment_txid IN (sqlc.slice('commitment_txids')) + OR EXISTS ( + SELECT 1 FROM vtxo_commitment_txid vc + WHERE vc.vtxo_txid = v.txid AND vc.vtxo_vout = v.vout + AND vc.commitment_txid IN (sqlc.slice('commitment_txids_alt')) + ) + ); + -- name: SelectSweepableVtxoOutpointsByCommitmentTxid :many SELECT DISTINCT v.txid AS vtxo_txid, v.vout AS vtxo_vout FROM vtxo_vw v diff --git a/internal/infrastructure/db/sqlite/vtxo_repo.go b/internal/infrastructure/db/sqlite/vtxo_repo.go index fd86e1fc5..7eb279486 100644 --- a/internal/infrastructure/db/sqlite/vtxo_repo.go +++ b/internal/infrastructure/db/sqlite/vtxo_repo.go @@ -479,6 +479,98 @@ func (v *vtxoRepository) GetVtxoPubKeysByCommitmentTxid( return taprootKeys, nil } +// sqliteVtxoPubKeysBatchSize bounds the number of commitment_txids bound to +// a single SelectVtxoPubKeysByCommitmentTxids invocation. Because the query +// expands the same slice into two IN clauses, each call binds (1 + 2N) +// parameters: one for the amount filter, plus N for each of the two slice +// placeholders. modernc.org/sqlite caps total bound parameters at +// SQLITE_MAX_VARIABLE_NUMBER = 32766; with batchSize = 5000 a single call +// binds at most 10001 params, leaving generous headroom. The wrapper below +// splits the input slice into batches of this size and merges the +// deduplicated results in Go. +const sqliteVtxoPubKeysBatchSize = 5000 + +// GetVtxoPubKeysByCommitmentTxids is the bulk variant of +// GetVtxoPubKeysByCommitmentTxid. It returns the deduplicated set of vtxo +// pubkeys whose root commitment_txid is in the given list, or whose +// vtxo_commitment_txid join row references one of those commitment txids. +// This replaces the per-round loops in restoreWatchingVtxos and Stop +// that previously fired one query per sweepable round (the N+1 pattern).
+// +// The input slice is split into batches of sqliteVtxoPubKeysBatchSize to +// stay below the sqlite parameter cap (see the const doc above). For inputs +// up to 5000 txids the loop iterates once and the behaviour is identical +// to a single underlying query call. +// +// The generated SelectVtxoPubKeysByCommitmentTxidsParams struct exposes two +// CommitmentTxids* fields, but the public API of this method takes a single +// slice and binds it to both. Both fields MUST receive the same slice; the +// query template has two distinct slice placeholders only because sqlc's +// sqlite generator expands a slice placeholder only once per generated +// query, and the bulk query needs to look in two places (the root column +// and the join table). Passing different slices to the two fields would +// silently produce wrong results, so all calls to the generated method +// must go through this wrapper. +func (v *vtxoRepository) GetVtxoPubKeysByCommitmentTxids( + ctx context.Context, commitmentTxids []string, withMinimumAmount uint64, +) ([]string, error) { + return v.getVtxoPubKeysByCommitmentTxidsBatched( + ctx, commitmentTxids, withMinimumAmount, sqliteVtxoPubKeysBatchSize, + ) +} + +// getVtxoPubKeysByCommitmentTxidsBatched is the testable inner of +// GetVtxoPubKeysByCommitmentTxids that splits commitmentTxids into chunks of +// batchSize, issues one query per chunk, and merges the deduplicated union +// of results in Go. Kept as a method (not a free function) so the +// underlying querier is the live one from the repository. batchSize <= 0 is +// treated as "no batching" and runs a single call. This differs from +// chunkStrings (which panics on size <= 0) because the production caller +// always passes sqliteVtxoPubKeysBatchSize > 0; the permissive fallback +// only matters when tests reach this method directly with batchSize=0. 
+func (v *vtxoRepository) getVtxoPubKeysByCommitmentTxidsBatched( + ctx context.Context, + commitmentTxids []string, + withMinimumAmount uint64, + batchSize int, +) ([]string, error) { + if len(commitmentTxids) == 0 { + return nil, nil + } + if batchSize <= 0 { + batchSize = len(commitmentTxids) + } + + seen := make(map[string]struct{}) + for start := 0; start < len(commitmentTxids); start += batchSize { + end := start + batchSize + if end > len(commitmentTxids) { + end = len(commitmentTxids) + } + batch := commitmentTxids[start:end] + // Same slice in both fields by construction; see public method + // doc for the sqlc dual-placeholder explanation. + keys, err := v.querier.SelectVtxoPubKeysByCommitmentTxids(ctx, + queries.SelectVtxoPubKeysByCommitmentTxidsParams{ + MinAmount: int64(withMinimumAmount), + CommitmentTxids: batch, + CommitmentTxidsAlt: batch, + }) + if err != nil { + return nil, err + } + for _, k := range keys { + seen[k] = struct{}{} + } + } + + out := make([]string, 0, len(seen)) + for k := range seen { + out = append(out, k) + } + return out, nil +} + func (v *vtxoRepository) GetPendingSpentVtxosWithPubKeys( ctx context.Context, pubkeys []string, after, before int64, ) ([]domain.Vtxo, error) { diff --git a/internal/infrastructure/db/sqlite/vtxo_repo_batching_test.go b/internal/infrastructure/db/sqlite/vtxo_repo_batching_test.go new file mode 100644 index 000000000..7669fe421 --- /dev/null +++ b/internal/infrastructure/db/sqlite/vtxo_repo_batching_test.go @@ -0,0 +1,181 @@ +package sqlitedb_test + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "testing" + + sqlitedb "github.com/arkade-os/arkd/internal/infrastructure/db/sqlite" + "github.com/stretchr/testify/require" +) + +// TestGetVtxoPubKeysByCommitmentTxidsBatched drives the multi-batch path +// of the unexported helper with small batch sizes against an in-memory +// sqlite DB. Guards against off-by-one errors in start/end slicing and +// missed dedup across batch boundaries. 
+func TestGetVtxoPubKeysByCommitmentTxidsBatched(t *testing.T) { + ctx := context.Background() + db, err := sqlitedb.OpenDb(":memory:") + require.NoError(t, err) + t.Cleanup(func() { + //nolint:errcheck + db.Close() + }) + + setupVtxoTables(t, db) + + // Seed seven commitment txids; each fans out to two vtxos. A handful + // of pubkeys appear under more than one commitment txid so the dedup + // path is exercised non-trivially. + const rounds = 7 + const vtxosPerRound = 2 + commitmentTxids := make([]string, 0, rounds) + expected := make(map[string]struct{}) + for r := 0; r < rounds; r++ { + commitmentTxid := fmt.Sprintf("commitment-%02d", r) + commitmentTxids = append(commitmentTxids, commitmentTxid) + for v := 0; v < vtxosPerRound; v++ { + pubkey := fmt.Sprintf("pubkey-%02d-%d", r, v) + vtxoTxid := fmt.Sprintf("vtxo-%02d-%d", r, v) + insertVtxoRow(t, db, vtxoTxid, v, pubkey, 1000, commitmentTxid) + // Cross-link the first vtxo (v == 0) of every round after the + // first to the previous round's commitment via the join table, + // so multiple batches can each return the same pubkey and the + // dedup logic has real work to do. + if r > 0 && v == 0 { + insertVtxoCommitmentTxidRow( + t, db, vtxoTxid, v, commitmentTxids[r-1], + ) + } + expected[pubkey] = struct{}{} + } + } + + repo, err := sqlitedb.NewVtxoRepository(db) + require.NoError(t, err) + + // 1, 2, 3 force the multi-batch loop; rounds-1 leaves a short tail + // batch; rounds and rounds+1 produce a single batch; 0 must fall + // through to the "no batching" branch.
+ for _, batchSize := range []int{1, 2, 3, rounds - 1, rounds, rounds + 1, 0} { + got, err := sqlitedb.GetVtxoPubKeysByCommitmentTxidsBatched( + ctx, repo, commitmentTxids, 0, batchSize, + ) + require.NoErrorf(t, err, "batchSize=%d", batchSize) + gotSet := make(map[string]struct{}, len(got)) + for _, k := range got { + gotSet[k] = struct{}{} + } + require.Equalf(t, len(got), len(gotSet), + "batchSize=%d: duplicates in result", batchSize) + require.Truef(t, reflect.DeepEqual(gotSet, expected), + "batchSize=%d: union mismatch (got %d unique, want %d)", + batchSize, len(gotSet), len(expected)) + } +} + +// TestGetVtxoPubKeysByCommitmentTxidsBatched_MinAmount verifies the +// withMinimumAmount predicate survives the per-batch query and merge. +func TestGetVtxoPubKeysByCommitmentTxidsBatched_MinAmount(t *testing.T) { + ctx := context.Background() + db, err := sqlitedb.OpenDb(":memory:") + require.NoError(t, err) + t.Cleanup(func() { + //nolint:errcheck + db.Close() + }) + + setupVtxoTables(t, db) + + // Two commitment txids, each with a below-threshold and an + // above-threshold vtxo. commitment-A also gets a vtxo whose amount + // equals min_amount to lock the inclusive >= predicate (the badger + // backend was previously > and is fixed in this PR for parity). 
+ commitmentTxids := []string{"commitment-A", "commitment-B"} + insertVtxoRow(t, db, "vtxo-a-low", 0, "pubkey-a-low", 100, commitmentTxids[0]) + insertVtxoRow(t, db, "vtxo-a-eq", 0, "pubkey-a-eq", 1000, commitmentTxids[0]) + insertVtxoRow(t, db, "vtxo-a-high", 0, "pubkey-a-high", 5000, commitmentTxids[0]) + insertVtxoRow(t, db, "vtxo-b-low", 0, "pubkey-b-low", 200, commitmentTxids[1]) + insertVtxoRow(t, db, "vtxo-b-high", 0, "pubkey-b-high", 7500, commitmentTxids[1]) + + repo, err := sqlitedb.NewVtxoRepository(db) + require.NoError(t, err) + + got, err := sqlitedb.GetVtxoPubKeysByCommitmentTxidsBatched( + ctx, repo, commitmentTxids, 1000, 1, + ) + require.NoError(t, err) + require.ElementsMatch(t, + []string{"pubkey-a-eq", "pubkey-a-high", "pubkey-b-high"}, got) +} + +func setupVtxoTables(t *testing.T, db *sql.DB) { + t.Helper() + _, err := db.Exec(` + CREATE TABLE IF NOT EXISTS vtxo ( + txid TEXT NOT NULL, + vout INTEGER NOT NULL, + pubkey TEXT NOT NULL, + amount INTEGER NOT NULL, + expires_at INTEGER NOT NULL DEFAULT 0, + created_at INTEGER NOT NULL DEFAULT 0, + commitment_txid TEXT NOT NULL, + spent_by TEXT, + spent BOOLEAN NOT NULL DEFAULT FALSE, + unrolled BOOLEAN NOT NULL DEFAULT FALSE, + swept BOOLEAN NOT NULL DEFAULT FALSE, + preconfirmed BOOLEAN NOT NULL DEFAULT FALSE, + settled_by TEXT, + ark_txid TEXT, + intent_id TEXT, + PRIMARY KEY (txid, vout) + ); + `) + require.NoError(t, err, "create vtxo table") + + _, err = db.Exec(` + CREATE TABLE IF NOT EXISTS vtxo_commitment_txid ( + vtxo_txid TEXT NOT NULL, + vtxo_vout INTEGER NOT NULL, + commitment_txid TEXT NOT NULL, + PRIMARY KEY (vtxo_txid, vtxo_vout, commitment_txid) + ); + `) + require.NoError(t, err, "create vtxo_commitment_txid table") +} + +func insertVtxoRow( + t *testing.T, + db *sql.DB, + txid string, + vout int, + pubkey string, + amount int64, + commitmentTxid string, +) { + t.Helper() + _, err := db.Exec( + `INSERT INTO vtxo (txid, vout, pubkey, amount, commitment_txid) `+ + `VALUES (?, ?, ?, 
?, ?)`, + txid, vout, pubkey, amount, commitmentTxid, + ) + require.NoError(t, err, "insert vtxo %s/%d", txid, vout) +} + +func insertVtxoCommitmentTxidRow( + t *testing.T, + db *sql.DB, + vtxoTxid string, + vtxoVout int, + commitmentTxid string, +) { + t.Helper() + _, err := db.Exec( + `INSERT INTO vtxo_commitment_txid (vtxo_txid, vtxo_vout, commitment_txid) `+ + `VALUES (?, ?, ?)`, + vtxoTxid, vtxoVout, commitmentTxid, + ) + require.NoError(t, err, "insert vtxo_commitment_txid %s/%d", vtxoTxid, vtxoVout) +} diff --git a/internal/infrastructure/db/vtxo_repo_bench_test.go b/internal/infrastructure/db/vtxo_repo_bench_test.go new file mode 100644 index 000000000..7681916e5 --- /dev/null +++ b/internal/infrastructure/db/vtxo_repo_bench_test.go @@ -0,0 +1,217 @@ +package db_test + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "testing" + "time" + + "github.com/arkade-os/arkd/internal/core/domain" + "github.com/arkade-os/arkd/internal/core/ports" + "github.com/arkade-os/arkd/internal/infrastructure/db" + "github.com/arkade-os/arkd/pkg/ark-lib/tree" + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/schnorr" + "github.com/google/uuid" +) + +// Benchmark sizes for the vtxo-pubkey lookup. The per-txid variant (the +// pre-fix N+1 pattern) is intentionally capped lower than the bulk +// variant because a single iteration of the per-txid loop fans out into +// one DB call per round, so 10 000 rounds takes several minutes per +// iteration. The bulk variant collapses that to two calls and can +// comfortably handle the larger sizes. +var ( + perTxidSizes = []int{10, 100, 1000} + bulkSizes = []int{10, 100, 1000, 5000} +) + +// vtxosPerRound is the per-round vtxo fan-out we seed. Total scripts in +// the seeded DB is N * vtxosPerRound. +const vtxosPerRound = 10 + +// newBenchService spins up a fresh sqlite+badger RepoManager rooted at a +// throwaway temp dir. 
Each benchmark sub-run gets its own DB so iteration +// counts are not biased by warm caches across sizes. +func newBenchService(tb testing.TB) ports.RepoManager { + tb.Helper() + dir := tb.TempDir() + svc, err := db.NewService(db.ServiceConfig{ + EventStoreType: "badger", + DataStoreType: "sqlite", + EventStoreConfig: []interface{}{"", nil}, + DataStoreConfig: []interface{}{dir}, + }, nil) + if err != nil { + tb.Fatalf("open db: %s", err) + } + tb.Cleanup(svc.Close) + return svc +} + +// seedSweepableRounds inserts numRounds sweepable Round records and +// numRounds*vtxosPerRound vtxo rows. Returns the list of commitment txids +// that GetSweepableRounds will subsequently return so callers do not have +// to re-query for them on the hot path. +func seedSweepableRounds(tb testing.TB, svc ports.RepoManager, numRounds int) []string { + tb.Helper() + ctx := context.Background() + commitmentTxids := make([]string, 0, numRounds) + now := time.Now().Unix() + + for r := 0; r < numRounds; r++ { + commitmentTxid := randomHex(tb, 32) + round := domain.Round{ + Id: uuid.New().String(), + StartingTimestamp: now - 60, + EndingTimestamp: now, + Stage: domain.Stage{Code: int(domain.RoundFinalizationStage), Ended: true}, + Intents: map[string]domain.Intent{}, + CommitmentTxid: commitmentTxid, + CommitmentTx: "bench-commitment-tx", + Version: 1, + VtxoTreeExpiration: 100, + // One synthetic 'tree' tx is enough to satisfy the EXISTS + // clause in SelectSweepableRounds. 
+ VtxoTree: tree.FlatTxTree{ + tree.TxTreeNode{Txid: randomHex(tb, 32), Tx: "bench-tree-tx"}, + }, + } + if err := svc.Rounds().AddOrUpdateRound(ctx, round); err != nil { + tb.Fatalf("AddOrUpdateRound[%d]: %s", r, err) + } + + vtxos := make([]domain.Vtxo, 0, vtxosPerRound) + for v := 0; v < vtxosPerRound; v++ { + vtxos = append(vtxos, domain.Vtxo{ + Outpoint: domain.Outpoint{Txid: randomHex(tb, 32), VOut: uint32(v)}, + Amount: 1000, + PubKey: randomXOnlyPubKey(tb), + CommitmentTxids: []string{commitmentTxid}, + RootCommitmentTxid: commitmentTxid, + CreatedAt: now, + ExpiresAt: now + 3600, + }) + } + if err := svc.Vtxos().AddVtxos(ctx, vtxos); err != nil { + tb.Fatalf("AddVtxos[%d]: %s", r, err) + } + commitmentTxids = append(commitmentTxids, commitmentTxid) + } + + return commitmentTxids +} + +// BenchmarkGetVtxoPubKeysByCommitmentTxid_PerTxidLoop reproduces the +// pre-fix code path: for each sweepable round, issue a singular +// per-commitment-txid query and union the results. This is what +// restoreWatchingVtxos used to do. +func BenchmarkGetVtxoPubKeysByCommitmentTxid_PerTxidLoop(b *testing.B) { + for _, n := range perTxidSizes { + b.Run(fmt.Sprintf("rounds=%d", n), func(b *testing.B) { + svc := newBenchService(b) + txids := seedSweepableRounds(b, svc, n) + ctx := context.Background() + b.ResetTimer() + for i := 0; i < b.N; i++ { + total := 0 + for _, t := range txids { + keys, err := svc.Vtxos(). + GetVtxoPubKeysByCommitmentTxid(ctx, t, 0) + if err != nil { + b.Fatalf("singular: %s", err) + } + total += len(keys) + } + if total == 0 { + b.Fatalf("expected non-zero keys at rounds=%d", n) + } + } + }) + } +} + +// BenchmarkGetVtxoPubKeysByCommitmentTxids_Bulk exercises the bulk +// variant that replaces the loop above. The bulk method runs one SQL +// query regardless of how many commitment txids it is given. 
+func BenchmarkGetVtxoPubKeysByCommitmentTxids_Bulk(b *testing.B) { + for _, n := range bulkSizes { + b.Run(fmt.Sprintf("rounds=%d", n), func(b *testing.B) { + svc := newBenchService(b) + txids := seedSweepableRounds(b, svc, n) + ctx := context.Background() + b.ResetTimer() + for i := 0; i < b.N; i++ { + keys, err := svc.Vtxos(). + GetVtxoPubKeysByCommitmentTxids(ctx, txids, 0) + if err != nil { + b.Fatalf("bulk: %s", err) + } + if len(keys) == 0 { + b.Fatalf("expected non-zero keys at rounds=%d", n) + } + } + }) + } +} + +// TestVtxoPubKeysBulkMatchesLoop asserts that the bulk query returns the +// same deduplicated pubkey set as the singular per-txid loop, so the +// benchmark comparison is apples to apples. Uses a modest size since it +// is part of the standard test suite and runs on every CI cycle. +func TestVtxoPubKeysBulkMatchesLoop(t *testing.T) { + svc := newBenchService(t) + txids := seedSweepableRounds(t, svc, 50) + + ctx := context.Background() + loopUnion := make(map[string]struct{}) + for _, txid := range txids { + keys, err := svc.Vtxos().GetVtxoPubKeysByCommitmentTxid(ctx, txid, 0) + if err != nil { + t.Fatalf("singular: %s", err) + } + for _, k := range keys { + loopUnion[k] = struct{}{} + } + } + + bulk, err := svc.Vtxos().GetVtxoPubKeysByCommitmentTxids(ctx, txids, 0) + if err != nil { + t.Fatalf("bulk: %s", err) + } + bulkSet := make(map[string]struct{}, len(bulk)) + for _, k := range bulk { + bulkSet[k] = struct{}{} + } + + if len(bulkSet) != len(loopUnion) { + t.Fatalf("bulk set size=%d loop union size=%d", len(bulkSet), len(loopUnion)) + } + for k := range loopUnion { + if _, ok := bulkSet[k]; !ok { + t.Fatalf("bulk missing pubkey %s", k) + } + } +} + +// randomHex returns 2*n hex characters. 
+func randomHex(tb testing.TB, n int) string { + buf := make([]byte, n) + if _, err := rand.Read(buf); err != nil { + tb.Fatalf("rand: %s", err) + } + return hex.EncodeToString(buf) +} + +// randomXOnlyPubKey returns a fresh schnorr x-only pubkey as a 64-char +// hex string. Each vtxo gets a unique pubkey so the bulk DISTINCT path +// is exercised non-trivially. +func randomXOnlyPubKey(tb testing.TB) string { + priv, err := btcec.NewPrivateKey() + if err != nil { + tb.Fatalf("priv: %s", err) + } + return hex.EncodeToString(schnorr.SerializePubKey(priv.PubKey())) +} diff --git a/internal/infrastructure/wallet/wallet_client.go b/internal/infrastructure/wallet/wallet_client.go index 1f4554197..e49e5b565 100644 --- a/internal/infrastructure/wallet/wallet_client.go +++ b/internal/infrastructure/wallet/wallet_client.go @@ -25,6 +25,11 @@ import ( type walletDaemonClient struct { client arkwalletv1.WalletServiceClient conn *grpc.ClientConn + // chunkSize bounds how many scripts are sent per WatchScripts / + // UnwatchScripts gRPC call. Zero means use defaultWatchScriptsChunkSize. + // Only set explicitly by tests; production callers go through New() and + // always get the default. + chunkSize int } // New creates a ports.WalletService backed by a gRPC client. @@ -99,14 +104,81 @@ func (w *walletDaemonClient) GetTransaction(ctx context.Context, txid string) (s return resp.GetTxHex(), nil } +// defaultWatchScriptsChunkSize bounds the number of scripts sent in a +// single WatchScripts / UnwatchScripts gRPC call when the caller has not +// configured an override. Each script is a hex-encoded taproot output +// (68 bytes) plus protobuf overhead, so 2000 scripts is roughly 150 KiB, +// well under the default gRPC 4 MiB message cap. +const defaultWatchScriptsChunkSize = 2000 + +// effectiveChunkSize returns the chunk size this client should use, +// falling back to the package default if no explicit size was set. 
+func (w *walletDaemonClient) effectiveChunkSize() int { + if w.chunkSize > 0 { + return w.chunkSize + } + return defaultWatchScriptsChunkSize +} + +// chunkStrings splits in into groups of at most size elements. The +// returned slices share backing storage with in, so callers must not +// mutate the input until they are done iterating. Panics on size <= 0 +// because the caller is the one in control of the size (it is a +// programming error to pass a non-positive value here) and silently +// returning the whole slice as one chunk would defeat the purpose of +// chunking. +func chunkStrings(in []string, size int) [][]string { + if size <= 0 { + panic(fmt.Sprintf("chunkStrings: size must be > 0, got %d", size)) + } + if len(in) == 0 { + return nil + } + chunks := make([][]string, 0, (len(in)+size-1)/size) + for i := 0; i < len(in); i += size { + end := i + size + if end > len(in) { + end = len(in) + } + chunks = append(chunks, in[i:end]) + } + return chunks +} + +// WatchScripts registers the given scripts with the wallet daemon. The +// scripts list is split into chunks of effectiveChunkSize() and sent as +// sequential gRPC calls so the request payload stays below the default +// 4 MiB gRPC max-message size at very large script counts (eg. boot-time +// restore of every tap key across all sweepable rounds). func (w *walletDaemonClient) WatchScripts(ctx context.Context, scripts []string) error { - _, err := w.client.WatchScripts(ctx, &arkwalletv1.WatchScriptsRequest{Scripts: scripts}) - return err + if len(scripts) == 0 { + return nil + } + for _, chunk := range chunkStrings(scripts, w.effectiveChunkSize()) { + _, err := w.client.WatchScripts( + ctx, &arkwalletv1.WatchScriptsRequest{Scripts: chunk}, + ) + if err != nil { + return err + } + } + return nil } +// UnwatchScripts is chunked for the same reason as WatchScripts. 
func (w *walletDaemonClient) UnwatchScripts(ctx context.Context, scripts []string) error { - _, err := w.client.UnwatchScripts(ctx, &arkwalletv1.UnwatchScriptsRequest{Scripts: scripts}) - return err + if len(scripts) == 0 { + return nil + } + for _, chunk := range chunkStrings(scripts, w.effectiveChunkSize()) { + _, err := w.client.UnwatchScripts( + ctx, &arkwalletv1.UnwatchScriptsRequest{Scripts: chunk}, + ) + if err != nil { + return err + } + } + return nil } func (w *walletDaemonClient) SignMessage(ctx context.Context, message []byte) ([]byte, error) { diff --git a/internal/infrastructure/wallet/wallet_client_test.go b/internal/infrastructure/wallet/wallet_client_test.go new file mode 100644 index 000000000..476f24d83 --- /dev/null +++ b/internal/infrastructure/wallet/wallet_client_test.go @@ -0,0 +1,244 @@ +package walletclient + +import ( + "context" + "errors" + "fmt" + "strconv" + "testing" + + arkwalletv1 "github.com/arkade-os/arkd/api-spec/protobuf/gen/arkwallet/v1" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" +) + +// fakeWalletClient stands in for an arkwalletv1.WalletServiceClient in tests. +// Embedding the interface (with a nil value) means any method we do not +// override will panic, which is fine since each test only drives one or two +// methods. WatchScripts and UnwatchScripts are overridden to record the +// chunks they receive and optionally fail on a chosen call. +type fakeWalletClient struct { + arkwalletv1.WalletServiceClient + watchCalls [][]string + unwatchCalls [][]string + failOnCallIdx int + failErr error +} + +func (f *fakeWalletClient) WatchScripts( + _ context.Context, in *arkwalletv1.WatchScriptsRequest, _ ...grpc.CallOption, +) (*arkwalletv1.WatchScriptsResponse, error) { + // Copy so callers can mutate without disturbing recorded state. + recorded := append([]string(nil), in.Scripts...) 
+ f.watchCalls = append(f.watchCalls, recorded) + if f.failErr != nil && len(f.watchCalls) == f.failOnCallIdx { + return nil, f.failErr + } + return &arkwalletv1.WatchScriptsResponse{}, nil +} + +func (f *fakeWalletClient) UnwatchScripts( + _ context.Context, in *arkwalletv1.UnwatchScriptsRequest, _ ...grpc.CallOption, +) (*arkwalletv1.UnwatchScriptsResponse, error) { + recorded := append([]string(nil), in.Scripts...) + f.unwatchCalls = append(f.unwatchCalls, recorded) + if f.failErr != nil && len(f.unwatchCalls) == f.failOnCallIdx { + return nil, f.failErr + } + return &arkwalletv1.UnwatchScriptsResponse{}, nil +} + +// newTestClient returns a walletDaemonClient bound to fake with the given +// chunk size. Construct one per test so the chunk size never escapes the +// test boundary and the tests can run in parallel safely. +func newTestClient(fake *fakeWalletClient, chunkSize int) *walletDaemonClient { + return &walletDaemonClient{client: fake, chunkSize: chunkSize} +} + +func makeScripts(n int) []string { + out := make([]string, n) + for i := 0; i < n; i++ { + // Content is irrelevant for chunking; the index is enough to assert + // the order of recorded chunks. 
+ out[i] = "s" + strconv.Itoa(i) + } + return out +} + +func TestChunkStrings(t *testing.T) { + tests := []struct { + name string + in []string + size int + want [][]string + }{ + { + name: "nil_input", + in: nil, + size: 100, + want: nil, + }, + { + name: "empty_input", + in: []string{}, + size: 100, + want: nil, + }, + { + name: "single_full_chunk", + in: []string{"a", "b", "c"}, + size: 10, + want: [][]string{{"a", "b", "c"}}, + }, + { + name: "exact_multiple", + in: []string{"a", "b", "c", "d"}, + size: 2, + want: [][]string{{"a", "b"}, {"c", "d"}}, + }, + { + name: "uneven_last_chunk", + in: []string{"a", "b", "c", "d", "e"}, + size: 2, + want: [][]string{{"a", "b"}, {"c", "d"}, {"e"}}, + }, + { + name: "size_one", + in: []string{"a", "b"}, + size: 1, + want: [][]string{{"a"}, {"b"}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := chunkStrings(tt.in, tt.size) + require.Equal(t, tt.want, got) + }) + } +} + +// TestChunkStringsBadSizePanics asserts that a non-positive size triggers a +// panic. The wrapper effectiveChunkSize() prevents this in production, so +// the only way to reach it is a programmer bug, which deserves a loud +// failure rather than a silent single-chunk fallback that would defeat +// the entire point of chunking. 
+func TestChunkStringsBadSizePanics(t *testing.T) { + for _, size := range []int{0, -1, -1000} { + t.Run(fmt.Sprintf("size=%d", size), func(t *testing.T) { + require.PanicsWithValue( + t, + fmt.Sprintf("chunkStrings: size must be > 0, got %d", size), + func() { _ = chunkStrings([]string{"a"}, size) }, + ) + }) + } +} + +func TestWalletClientWatchScriptsChunking(t *testing.T) { + t.Run("empty_input_no_calls", func(t *testing.T) { + fake := &fakeWalletClient{} + c := newTestClient(fake, 0) + require.NoError(t, c.WatchScripts(context.Background(), nil)) + require.NoError(t, c.WatchScripts(context.Background(), []string{})) + require.Empty(t, fake.watchCalls) + }) + + t.Run("single_chunk_when_under_limit", func(t *testing.T) { + fake := &fakeWalletClient{} + c := newTestClient(fake, 100) + scripts := makeScripts(75) + require.NoError(t, c.WatchScripts(context.Background(), scripts)) + require.Len(t, fake.watchCalls, 1) + require.Equal(t, scripts, fake.watchCalls[0]) + }) + + t.Run("exact_chunk_boundary", func(t *testing.T) { + fake := &fakeWalletClient{} + c := newTestClient(fake, 100) + scripts := makeScripts(100) + require.NoError(t, c.WatchScripts(context.Background(), scripts)) + require.Len(t, fake.watchCalls, 1) + require.Len(t, fake.watchCalls[0], 100) + }) + + t.Run("splits_above_boundary", func(t *testing.T) { + fake := &fakeWalletClient{} + c := newTestClient(fake, 100) + scripts := makeScripts(101) + require.NoError(t, c.WatchScripts(context.Background(), scripts)) + require.Len(t, fake.watchCalls, 2) + require.Len(t, fake.watchCalls[0], 100) + require.Len(t, fake.watchCalls[1], 1) + }) + + t.Run("large_input_round_trips_intact", func(t *testing.T) { + fake := &fakeWalletClient{} + c := newTestClient(fake, 250) + scripts := makeScripts(1000) + require.NoError(t, c.WatchScripts(context.Background(), scripts)) + // Expect 4 chunks of 250 each. 
+ require.Len(t, fake.watchCalls, 4) + for _, c := range fake.watchCalls { + require.Len(t, c, 250) + } + // Reassemble and confirm order plus completeness. + var reassembled []string + for _, c := range fake.watchCalls { + reassembled = append(reassembled, c...) + } + require.Equal(t, scripts, reassembled) + }) + + t.Run("error_on_middle_chunk_short_circuits", func(t *testing.T) { + boom := errors.New("simulated grpc failure") + fake := &fakeWalletClient{failOnCallIdx: 3, failErr: boom} + c := newTestClient(fake, 10) + err := c.WatchScripts(context.Background(), makeScripts(100)) + require.ErrorIs(t, err, boom) + // Three chunks attempted, the third returned the error. No further + // calls should fire. + require.Len(t, fake.watchCalls, 3) + }) + + t.Run("default_chunk_size_used_when_unset", func(t *testing.T) { + fake := &fakeWalletClient{} + // chunkSize=0 falls back to defaultWatchScriptsChunkSize. Drive + // just enough scripts to land exactly two chunks at the default + // to confirm the fallback path actually fires. 
+ c := newTestClient(fake, 0) + require.NoError( + t, + c.WatchScripts(context.Background(), makeScripts(defaultWatchScriptsChunkSize+1)), + ) + require.Len(t, fake.watchCalls, 2) + require.Len(t, fake.watchCalls[0], defaultWatchScriptsChunkSize) + require.Len(t, fake.watchCalls[1], 1) + }) +} + +func TestWalletClientUnwatchScriptsChunking(t *testing.T) { + t.Run("empty_input_no_calls", func(t *testing.T) { + fake := &fakeWalletClient{} + c := newTestClient(fake, 0) + require.NoError(t, c.UnwatchScripts(context.Background(), nil)) + require.Empty(t, fake.unwatchCalls) + }) + + t.Run("splits_above_boundary", func(t *testing.T) { + fake := &fakeWalletClient{} + c := newTestClient(fake, 50) + require.NoError(t, c.UnwatchScripts(context.Background(), makeScripts(151))) + require.Len(t, fake.unwatchCalls, 4) + require.Len(t, fake.unwatchCalls[3], 1) + }) + + t.Run("error_propagates", func(t *testing.T) { + boom := fmt.Errorf("nope") + fake := &fakeWalletClient{failOnCallIdx: 1, failErr: boom} + c := newTestClient(fake, 5) + err := c.UnwatchScripts(context.Background(), makeScripts(20)) + require.ErrorIs(t, err, boom) + require.Len(t, fake.unwatchCalls, 1) + }) +}