Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/release-notes/release-notes-0.22.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
the close transaction is actually broadcast, and
`WaitingCloseChannel.ClosingTx` is never empty.

* [Fixed `ListPayments`](https://github.com/lightningnetwork/lnd/pull/10874)
so `count_total_payments` respects the `creation_date_start` and
`creation_date_end` filters.

# New Features

## Functional Enhancements
Expand Down
141 changes: 125 additions & 16 deletions payments/db/kv_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,26 @@ func (p *KVStore) QueryPayments(_ context.Context,

var resp Response

matchesCreationDate := func(creationTime time.Time) bool {
// Get the creation time in Unix seconds, this always rounds down
// the nanoseconds to full seconds.
createTime := creationTime.Unix()

// Skip any payments that were created before the specified time.
if createTime < query.CreationDateStart {
return false
}

// Skip any payments that were created after the specified time.
if query.CreationDateEnd != 0 &&
createTime > query.CreationDateEnd {

return false
}

return true
}

if err := kvdb.View(p.db, func(tx kvdb.RTx) error {
// Get the root payments bucket.
paymentsBucket := tx.ReadBucket(paymentsRootBucket)
Expand Down Expand Up @@ -1096,21 +1116,7 @@ func (p *KVStore) QueryPayments(_ context.Context,
return false, err
}

// Get the creation time in Unix seconds, this always
// rounds down the nanoseconds to full seconds.
createTime := payment.Info.CreationTime.Unix()

// Skip any payments that were created before the
// specified time.
if createTime < query.CreationDateStart {
return false, nil
}

// Skip any payments that were created after the
// specified time.
if query.CreationDateEnd != 0 &&
createTime > query.CreationDateEnd {

if !matchesCreationDate(payment.Info.CreationTime) {
return false, nil
}

Expand Down Expand Up @@ -1141,7 +1147,28 @@ func (p *KVStore) QueryPayments(_ context.Context,
totalPayments uint64
err error
)
countFn := func(_, _ []byte) error {
countFn := func(sequenceKey, hash []byte) error {
if query.CreationDateStart != 0 ||
query.CreationDateEnd != 0 {

r := bytes.NewReader(hash)
paymentHash, err := deserializePaymentIndex(r)
if err != nil {
return err
}

creationTime, err := fetchCreationTimeWithSequenceNumber(
tx, paymentHash, sequenceKey,
)
if err != nil {
return err
}

if !matchesCreationDate(creationTime) {
return nil
}
}

totalPayments++

return nil
Expand Down Expand Up @@ -1189,6 +1216,88 @@ func (p *KVStore) QueryPayments(_ context.Context,
return resp, nil
}

// fetchCreationTimeWithSequenceNumber gets the creation time for the payment
// that matches the payment hash and sequence number. This is the lightweight
// counterpart to fetchPaymentWithSequenceNumber for callers that only need the
// query metadata and should avoid loading all HTLC attempts.
func fetchCreationTimeWithSequenceNumber(tx kvdb.RTx,
paymentHash lntypes.Hash, sequenceNumber []byte) (time.Time, error) {

bucket, err := fetchPaymentBucket(tx, paymentHash)
if err != nil {
return time.Time{}, err
}

seqBytes := bucket.Get(paymentSequenceKey)
if seqBytes == nil {
return time.Time{}, ErrNoSequenceNumber
}

if bytes.Equal(seqBytes, sequenceNumber) {
creationInfo, err := fetchCreationInfo(bucket)
if err != nil {
return time.Time{}, err
}

return creationInfo.CreationTime, nil
}

dup := bucket.NestedReadBucket(duplicatePaymentsBucket)
if dup == nil {
return time.Time{}, ErrNoDuplicateBucket
}

var (
creationTime time.Time
found bool
)
err = dup.ForEach(func(k, _ []byte) error {
if found {
return nil
}

subBucket := dup.NestedReadBucket(k)
if subBucket == nil {
return ErrNoDuplicateNestedBucket
}

seqBytes := subBucket.Get(duplicatePaymentSequenceKey)
if seqBytes == nil {
return ErrNoDuplicateSequenceNumber
}

if !bytes.Equal(seqBytes, sequenceNumber) {
return nil
}

b := subBucket.Get(duplicatePaymentCreationInfoKey)
if b == nil {
return fmt.Errorf("creation info not found")
}

creationInfo, err := deserializeDuplicatePaymentCreationInfo(
bytes.NewReader(b),
)
if err != nil {
return err
}

creationTime = creationInfo.CreationTime
found = true

return nil
})
if err != nil {
return time.Time{}, err
}

if !found {
return time.Time{}, ErrDuplicateNotFound
}

return creationTime, nil
}

// fetchPaymentWithSequenceNumber get the payment which matches the payment hash
// *and* sequence number provided from the database. This is required because
// we previously had more than one payment per hash, so we have multiple indexes
Expand Down
23 changes: 23 additions & 0 deletions payments/db/payment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2424,6 +2424,10 @@ func TestQueryPayments(t *testing.T) {
firstIndex uint64
lastIndex uint64

// expectedTotal is the expected TotalCount when CountTotal is
// set. A zero value means the unfiltered total is expected.
expectedTotal uint64

// expectedSeqNrs contains the set of sequence numbers we expect
// our query to return.
expectedSeqNrs []uint64
Expand Down Expand Up @@ -2717,6 +2721,22 @@ func TestQueryPayments(t *testing.T) {
lastIndex: 6,
expectedSeqNrs: []uint64{6},
},
{
name: "count total with date filters",
query: Query{
IndexOffset: 0,
MaxPayments: 2,
Reversed: false,
IncludeIncomplete: true,
CountTotal: true,
CreationDateStart: 3,
CreationDateEnd: 5,
},
firstIndex: 3,
lastIndex: 4,
expectedTotal: 3,
expectedSeqNrs: []uint64{3, 4},
},
}

for _, tt := range tests {
Expand Down Expand Up @@ -2870,6 +2890,9 @@ func TestQueryPayments(t *testing.T) {
// We should have 5 total payments
// (6 created - 1 deleted).
expectedTotal := uint64(5)
if tt.expectedTotal != 0 {
expectedTotal = tt.expectedTotal
}
require.Equal(
t, expectedTotal, querySlice.TotalCount,
"expected total count %v, got %v",
Expand Down
49 changes: 26 additions & 23 deletions payments/db/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ type SQLQueries interface {
FetchNonTerminalPayments(ctx context.Context, arg sqlc.FetchNonTerminalPaymentsParams) ([]sqlc.FetchNonTerminalPaymentsRow, error)

CountPayments(ctx context.Context) (int64, error)
CountFilteredPayments(ctx context.Context, query sqlc.CountFilteredPaymentsParams) (int64, error)

FetchHtlcAttemptsForPayments(ctx context.Context, paymentIDs []int64) ([]sqlc.FetchHtlcAttemptsForPaymentsRow, error)
FetchHtlcAttemptResolutionsForPayments(ctx context.Context, paymentIDs []int64) ([]sqlc.FetchHtlcAttemptResolutionsForPaymentsRow, error)
Expand Down Expand Up @@ -690,11 +691,35 @@ func (s *SQLStore) QueryPayments(ctx context.Context, query Query) (Response,
return row.Payment.ID
}

// Default date bounds: epoch start and far future. These are always
// provided so the SQL query uses simple comparisons instead of COALESCE
// (which causes type mismatch on Postgres) or OR-based optional filters
// (which can prevent index usage).
createdAfter := time.Unix(0, 0).UTC()
if query.CreationDateStart != 0 {
createdAfter = time.Unix(query.CreationDateStart, 0).UTC()
}

createdBefore := time.Date(
9999, 12, 31, 23, 59, 59, 0, time.UTC,
)
if query.CreationDateEnd != 0 {
createdBefore = time.Unix(query.CreationDateEnd, 0).UTC()
}

err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
// We first count all payments to determine the total count
// if requested.
if query.CountTotal {
totalPayments, err := db.CountPayments(ctx)
totalPayments, err := db.CountFilteredPayments(
ctx, sqlc.CountFilteredPaymentsParams{
CreatedAfter: createdAfter,
CreatedBefore: createdBefore,
IntentType: sqldb.SQLInt16(
PaymentIntentTypeBolt11,
),
},
)
if err != nil {
return fmt.Errorf("failed to count "+
"payments: %w", err)
Expand Down Expand Up @@ -768,28 +793,6 @@ func (s *SQLStore) QueryPayments(ctx context.Context, query Query) (Response,
queryFunc := func(ctx context.Context, lastID int64,
limit int32) ([]sqlc.FilterPaymentsRow, error) {

// Default date bounds: epoch start and far
// future. These are always provided so the SQL
// query uses simple comparisons instead of
// COALESCE (which causes type mismatch on
// Postgres) or OR-based optional filters (which
// can prevent index usage).
createdAfter := time.Unix(0, 0).UTC()
if query.CreationDateStart != 0 {
createdAfter = time.Unix(
query.CreationDateStart, 0,
).UTC()
}

createdBefore := time.Date(
9999, 12, 31, 23, 59, 59, 0, time.UTC,
)
if query.CreationDateEnd != 0 {
createdBefore = time.Unix(
query.CreationDateEnd, 0,
).UTC()
}

filterParams := sqlc.FilterPaymentsParams{
NumLimit: limit,
CreatedAfter: createdAfter,
Expand Down
29 changes: 29 additions & 0 deletions sqldb/sqlc/payments.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions sqldb/sqlc/querier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions sqldb/sqlc/queries/payments.sql
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ WHERE p.id > COALESCE(sqlc.narg('index_offset_get'), -1)
ORDER BY p.id DESC
LIMIT @num_limit;

-- name: CountFilteredPayments :one
SELECT COUNT(*)
FROM payments p
LEFT JOIN payment_intents i ON i.payment_id = p.id
WHERE p.created_at >= @created_after
AND p.created_at <= @created_before
AND (
i.intent_type = sqlc.narg('intent_type') OR
sqlc.narg('intent_type') IS NULL OR i.intent_type IS NULL
);

-- name: FetchPayment :one
SELECT
sqlc.embed(p),
Expand Down
Loading