From 3a09cdb2c7974090a31045a89258bc8c6b8ccb15 Mon Sep 17 00:00:00 2001 From: Julio Cesar Date: Wed, 3 Jun 2026 22:21:31 +0200 Subject: [PATCH] payments: count total with date filters --- docs/release-notes/release-notes-0.22.0.md | 4 + payments/db/kv_store.go | 141 ++++++++++++++++++--- payments/db/payment_test.go | 23 ++++ payments/db/sql_store.go | 49 +++---- sqldb/sqlc/payments.sql.go | 29 +++++ sqldb/sqlc/querier.go | 1 + sqldb/sqlc/queries/payments.sql | 11 ++ 7 files changed, 219 insertions(+), 39 deletions(-) diff --git a/docs/release-notes/release-notes-0.22.0.md b/docs/release-notes/release-notes-0.22.0.md index 45e76051922..ff08d579c0f 100644 --- a/docs/release-notes/release-notes-0.22.0.md +++ b/docs/release-notes/release-notes-0.22.0.md @@ -34,6 +34,10 @@ the close transaction is actually broadcast, and `WaitingCloseChannel.ClosingTx` is never empty. +* [Fixed `ListPayments`](https://github.com/lightningnetwork/lnd/pull/10874) + so `count_total_payments` respects the `creation_date_start` and + `creation_date_end` filters. + # New Features ## Functional Enhancements diff --git a/payments/db/kv_store.go b/payments/db/kv_store.go index b3fffe1e182..8ca40d9aa7b 100644 --- a/payments/db/kv_store.go +++ b/payments/db/kv_store.go @@ -1053,6 +1053,26 @@ func (p *KVStore) QueryPayments(_ context.Context, var resp Response + matchesCreationDate := func(creationTime time.Time) bool { + // Get the creation time in Unix seconds, this always rounds down + // the nanoseconds to full seconds. + createTime := creationTime.Unix() + + // Skip any payments that were created before the specified time. + if createTime < query.CreationDateStart { + return false + } + + // Skip any payments that were created after the specified time. + if query.CreationDateEnd != 0 && + createTime > query.CreationDateEnd { + + return false + } + + return true + } + if err := kvdb.View(p.db, func(tx kvdb.RTx) error { // Get the root payments bucket. paymentsBucket := tx.ReadBucket(paymentsRootBucket) @@ -1096,21 +1116,7 @@ func (p *KVStore) QueryPayments(_ context.Context, return false, err } - // Get the creation time in Unix seconds, this always - // rounds down the nanoseconds to full seconds. - createTime := payment.Info.CreationTime.Unix() - - // Skip any payments that were created before the - // specified time. - if createTime < query.CreationDateStart { - return false, nil - } - - // Skip any payments that were created after the - // specified time. - if query.CreationDateEnd != 0 && - createTime > query.CreationDateEnd { - + if !matchesCreationDate(payment.Info.CreationTime) { return false, nil } @@ -1141,7 +1147,28 @@ func (p *KVStore) QueryPayments(_ context.Context, totalPayments uint64 err error ) - countFn := func(_, _ []byte) error { + countFn := func(sequenceKey, hash []byte) error { + if query.CreationDateStart != 0 || + query.CreationDateEnd != 0 { + + r := bytes.NewReader(hash) + paymentHash, err := deserializePaymentIndex(r) + if err != nil { + return err + } + + creationTime, err := fetchCreationTimeWithSequenceNumber( + tx, paymentHash, sequenceKey, + ) + if err != nil { + return err + } + + if !matchesCreationDate(creationTime) { + return nil + } + } + totalPayments++ return nil @@ -1189,6 +1216,88 @@ func (p *KVStore) QueryPayments(_ context.Context, return resp, nil } +// fetchCreationTimeWithSequenceNumber gets the creation time for the payment +// that matches the payment hash and sequence number. This is the lightweight +// counterpart to fetchPaymentWithSequenceNumber for callers that only need the +// query metadata and should avoid loading all HTLC attempts. +func fetchCreationTimeWithSequenceNumber(tx kvdb.RTx, + paymentHash lntypes.Hash, sequenceNumber []byte) (time.Time, error) { + + bucket, err := fetchPaymentBucket(tx, paymentHash) + if err != nil { + return time.Time{}, err + } + + seqBytes := bucket.Get(paymentSequenceKey) + if seqBytes == nil { + return time.Time{}, ErrNoSequenceNumber + } + + if bytes.Equal(seqBytes, sequenceNumber) { + creationInfo, err := fetchCreationInfo(bucket) + if err != nil { + return time.Time{}, err + } + + return creationInfo.CreationTime, nil + } + + dup := bucket.NestedReadBucket(duplicatePaymentsBucket) + if dup == nil { + return time.Time{}, ErrNoDuplicateBucket + } + + var ( + creationTime time.Time + found bool + ) + err = dup.ForEach(func(k, _ []byte) error { + if found { + return nil + } + + subBucket := dup.NestedReadBucket(k) + if subBucket == nil { + return ErrNoDuplicateNestedBucket + } + + seqBytes := subBucket.Get(duplicatePaymentSequenceKey) + if seqBytes == nil { + return ErrNoDuplicateSequenceNumber + } + + if !bytes.Equal(seqBytes, sequenceNumber) { + return nil + } + + b := subBucket.Get(duplicatePaymentCreationInfoKey) + if b == nil { + return fmt.Errorf("creation info not found") + } + + creationInfo, err := deserializeDuplicatePaymentCreationInfo( + bytes.NewReader(b), + ) + if err != nil { + return err + } + + creationTime = creationInfo.CreationTime + found = true + + return nil + }) + if err != nil { + return time.Time{}, err + } + + if !found { + return time.Time{}, ErrDuplicateNotFound + } + + return creationTime, nil +} + // fetchPaymentWithSequenceNumber get the payment which matches the payment hash // *and* sequence number provided from the database. This is required because // we previously had more than one payment per hash, so we have multiple indexes diff --git a/payments/db/payment_test.go b/payments/db/payment_test.go index e304b136dd2..6a3f83fce59 100644 --- a/payments/db/payment_test.go +++ b/payments/db/payment_test.go @@ -2424,6 +2424,10 @@ func TestQueryPayments(t *testing.T) { firstIndex uint64 lastIndex uint64 + // expectedTotal is the expected TotalCount when CountTotal is + // set. A zero value means the unfiltered total is expected. + expectedTotal uint64 + // expectedSeqNrs contains the set of sequence numbers we expect // our query to return. expectedSeqNrs []uint64 @@ -2717,6 +2721,22 @@ func TestQueryPayments(t *testing.T) { lastIndex: 6, expectedSeqNrs: []uint64{6}, }, + { + name: "count total with date filters", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + CountTotal: true, + CreationDateStart: 3, + CreationDateEnd: 5, + }, + firstIndex: 3, + lastIndex: 4, + expectedTotal: 3, + expectedSeqNrs: []uint64{3, 4}, + }, } for _, tt := range tests { @@ -2870,6 +2890,9 @@ func TestQueryPayments(t *testing.T) { // We should have 5 total payments // (6 created - 1 deleted). expectedTotal := uint64(5) + if tt.expectedTotal != 0 { + expectedTotal = tt.expectedTotal + } require.Equal( t, expectedTotal, querySlice.TotalCount, "expected total count %v, got %v", diff --git a/payments/db/sql_store.go b/payments/db/sql_store.go index 3d92385fd0f..95e8852a819 100644 --- a/payments/db/sql_store.go +++ b/payments/db/sql_store.go @@ -52,6 +52,7 @@ type SQLQueries interface { FetchNonTerminalPayments(ctx context.Context, arg sqlc.FetchNonTerminalPaymentsParams) ([]sqlc.FetchNonTerminalPaymentsRow, error) CountPayments(ctx context.Context) (int64, error) + CountFilteredPayments(ctx context.Context, query sqlc.CountFilteredPaymentsParams) (int64, error) FetchHtlcAttemptsForPayments(ctx context.Context, paymentIDs []int64) ([]sqlc.FetchHtlcAttemptsForPaymentsRow, error) FetchHtlcAttemptResolutionsForPayments(ctx context.Context, paymentIDs []int64) ([]sqlc.FetchHtlcAttemptResolutionsForPaymentsRow, error) @@ -690,11 +691,35 @@ func (s *SQLStore) QueryPayments(ctx context.Context, query Query) (Response, return row.Payment.ID } + // Default date bounds: epoch start and far future. These are always + // provided so the SQL query uses simple comparisons instead of COALESCE + // (which causes type mismatch on Postgres) or OR-based optional filters + // (which can prevent index usage). + createdAfter := time.Unix(0, 0).UTC() + if query.CreationDateStart != 0 { + createdAfter = time.Unix(query.CreationDateStart, 0).UTC() + } + + createdBefore := time.Date( + 9999, 12, 31, 23, 59, 59, 0, time.UTC, + ) + if query.CreationDateEnd != 0 { + createdBefore = time.Unix(query.CreationDateEnd, 0).UTC() + } + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { // We first count all payments to determine the total count // if requested. if query.CountTotal { - totalPayments, err := db.CountPayments(ctx) + totalPayments, err := db.CountFilteredPayments( + ctx, sqlc.CountFilteredPaymentsParams{ + CreatedAfter: createdAfter, + CreatedBefore: createdBefore, + IntentType: sqldb.SQLInt16( + PaymentIntentTypeBolt11, + ), + }, + ) if err != nil { return fmt.Errorf("failed to count "+ "payments: %w", err) @@ -768,28 +793,6 @@ func (s *SQLStore) QueryPayments(ctx context.Context, query Query) (Response, queryFunc := func(ctx context.Context, lastID int64, limit int32) ([]sqlc.FilterPaymentsRow, error) { - // Default date bounds: epoch start and far - // future. These are always provided so the SQL - // query uses simple comparisons instead of - // COALESCE (which causes type mismatch on - // Postgres) or OR-based optional filters (which - // can prevent index usage). - createdAfter := time.Unix(0, 0).UTC() - if query.CreationDateStart != 0 { - createdAfter = time.Unix( - query.CreationDateStart, 0, - ).UTC() - } - - createdBefore := time.Date( - 9999, 12, 31, 23, 59, 59, 0, time.UTC, - ) - if query.CreationDateEnd != 0 { - createdBefore = time.Unix( - query.CreationDateEnd, 0, - ).UTC() - } - filterParams := sqlc.FilterPaymentsParams{ NumLimit: limit, CreatedAfter: createdAfter, diff --git a/sqldb/sqlc/payments.sql.go b/sqldb/sqlc/payments.sql.go index 42a8fb826f0..ace1b5144d2 100644 --- a/sqldb/sqlc/payments.sql.go +++ b/sqldb/sqlc/payments.sql.go @@ -23,6 +23,35 @@ func (q *Queries) CountPayments(ctx context.Context) (int64, error) { return count, err } +const countFilteredPayments = `-- name: CountFilteredPayments :one +SELECT COUNT(*) +FROM payments p +LEFT JOIN payment_intents i ON i.payment_id = p.id +WHERE p.created_at >= $1 + AND p.created_at <= $2 + AND ( + i.intent_type = $3 OR + $3 IS NULL OR i.intent_type IS NULL + ) +` + +type CountFilteredPaymentsParams struct { + CreatedAfter time.Time + CreatedBefore time.Time + IntentType sql.NullInt16 +} + +func (q *Queries) CountFilteredPayments(ctx context.Context, arg CountFilteredPaymentsParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countFilteredPayments, + arg.CreatedAfter, + arg.CreatedBefore, + arg.IntentType, + ) + var count int64 + err := row.Scan(&count) + return count, err +} + const deleteFailedAttempts = `-- name: DeleteFailedAttempts :exec DELETE FROM payment_htlc_attempts WHERE payment_id = $1 diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index 9b95a669917..7497cf7f728 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -15,6 +15,7 @@ type Querier interface { AddV1ChannelProof(ctx context.Context, arg AddV1ChannelProofParams) (sql.Result, error) AddV2ChannelProof(ctx context.Context, arg AddV2ChannelProofParams) (sql.Result, error) ClearKVInvoiceHashIndex(ctx context.Context) error + CountFilteredPayments(ctx context.Context, arg CountFilteredPaymentsParams) (int64, error) CountPayments(ctx context.Context) (int64, error) CountZombieChannels(ctx context.Context, version int16) (int64, error) CreateChannel(ctx context.Context, arg CreateChannelParams) (int64, error) diff --git a/sqldb/sqlc/queries/payments.sql b/sqldb/sqlc/queries/payments.sql index 16682c83bc3..bc3f3b1ab0c 100644 --- a/sqldb/sqlc/queries/payments.sql +++ b/sqldb/sqlc/queries/payments.sql @@ -47,6 +47,17 @@ WHERE p.id > COALESCE(sqlc.narg('index_offset_get'), -1) ORDER BY p.id DESC LIMIT @num_limit; +-- name: CountFilteredPayments :one +SELECT COUNT(*) +FROM payments p +LEFT JOIN payment_intents i ON i.payment_id = p.id +WHERE p.created_at >= @created_after + AND p.created_at <= @created_before + AND ( + i.intent_type = sqlc.narg('intent_type') OR + sqlc.narg('intent_type') IS NULL OR i.intent_type IS NULL + ); + -- name: FetchPayment :one SELECT sqlc.embed(p),