From e9cf1796250a6705b8b5d2d2b511c4391a9de3df Mon Sep 17 00:00:00 2001 From: bitromortac Date: Tue, 5 May 2026 14:05:07 +0200 Subject: [PATCH 01/11] lnwire: generalize pure TLV signed-range filtering Add UnsignedRangeFunc and the SerialiseFieldsToSignFn / ExtraSignedFieldsFromTypeMapFn variants so callers with non-BOLT 7 v2 signed ranges (e.g. BOLT 12, which reserves only 240-1000) can plug in their own predicate. The existing SerialiseFieldsToSign and ExtraSignedFieldsFromTypeMap entry points keep their behaviour by delegating to the Fn variants with InUnsignedRange. --- lnwire/pure_tlv.go | 76 +++++++++++++++++++++++------------- lnwire/pure_tlv_test.go | 86 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 26 deletions(-) diff --git a/lnwire/pure_tlv.go b/lnwire/pure_tlv.go index 8e6f7bd9fc3..6692ac2f53c 100644 --- a/lnwire/pure_tlv.go +++ b/lnwire/pure_tlv.go @@ -23,12 +23,12 @@ const ( ) // PureTLVMessage describes an LN message that is a pure TLV stream. If the -// message includes a signature, it will sign all the TLV records in the -// inclusive ranges: 0 to 159 and 1000000000 to 2999999999. +// message includes a signature, the signature covers a subset of the records, +// which subset is determined by the protocol's signed/unsigned range (see +// SerialiseFieldsToSignFn). type PureTLVMessage interface { - // AllRecords returns all the TLV records for the message. This will - // include all the records we know about along with any that we don't - // know about but that fall in the signed TLV range. + // AllRecords returns all the TLV records for the message, including + // both records we know about and unknown records that we preserve. AllRecords() []tlv.Record } @@ -37,13 +37,27 @@ func EncodePureTLVMessage(msg PureTLVMessage, buf *bytes.Buffer) error { return EncodeRecordsTo(buf, msg.AllRecords()) } +// UnsignedRangeFunc returns true when a TLV type is in the unsigned range of a +// pure-TLV message (i.e., excluded from the signature). Each protocol supplies +// its own predicate to encode the boundary between signed and unsigned types. +type UnsignedRangeFunc func(tlv.Type) bool + // SerialiseFieldsToSign serialises all the records from the given -// PureTLVMessage that fall within the signed TLV range. +// PureTLVMessage that fall within the BOLT 7 v2 signed TLV range. Use +// SerialiseFieldsToSignFn for a protocol with a different boundary. func SerialiseFieldsToSign(msg PureTLVMessage) ([]byte, error) { - // Filter out all the fields not in the signed ranges. + return SerialiseFieldsToSignFn(msg, InUnsignedRange) +} + +// SerialiseFieldsToSignFn serialises all the records from the given +// PureTLVMessage that the supplied predicate keeps in the signed range. A type +// for which isUnsigned returns true is excluded from the digest. +func SerialiseFieldsToSignFn(msg PureTLVMessage, + isUnsigned UnsignedRangeFunc) ([]byte, error) { + var signedRecords []tlv.Record for _, record := range msg.AllRecords() { - if InUnsignedRange(record.Type()) { + if isUnsigned(record.Type()) { continue } @@ -58,8 +72,9 @@ func SerialiseFieldsToSign(msg PureTLVMessage) ([]byte, error) { return buf.Bytes(), nil } -// InUnsignedRange returns true if the given TLV type falls outside the TLV -// ranges that the signature of a pure TLV message will cover. +// InUnsignedRange is the BOLT 7 v2 UnsignedRangeFunc: it returns true for types +// in 160-999_999_999 or 3_000_000_000+, which sit outside the BOLT 7 v2 signed +// ranges (0-159 and 1_000_000_000-2_999_999_999). func InUnsignedRange(t tlv.Type) bool { return (t >= pureTLVUnsignedRangeOneStart && t < pureTLVSignedSecondRangeStart) || @@ -72,32 +87,41 @@ func InUnsignedRange(t tlv.Type) bool { // for re-composing the wire message since the signature covers these fields. type ExtraSignedFields map[uint64][]byte -// ExtraSignedFieldsFromTypeMap is a helper that can be used alongside calls to -// the tlv.Stream DecodeWithParsedTypesP2P or DecodeWithParsedTypes methods to -// extract the tlv type and value pairs in the defined PureTLVMessage signed -// range which we have not handled with any of our defined Records. These -// methods will return a tlv.TypeMap containing the records that were extracted -// from an io.Reader. If the record was know and handled by a defined record, -// then the value accompanying the record's type in the map will be nil. -// Otherwise, if the record was unhandled, it will be non-nil. +// ExtraSignedFieldsFromTypeMap returns the unhandled signed-range entries from +// a tlv.TypeMap (as returned by DecodeWithParsedTypes(P2P)) so the caller can +// re-emit them and keep the message signature valid. It uses the BOLT 7 v2 +// signed range; use ExtraSignedFieldsFromTypeMapFn for a different boundary. func ExtraSignedFieldsFromTypeMap(m tlv.TypeMap) ExtraSignedFields { + return ExtraSignedFieldsFromTypeMapFn(m, InUnsignedRange) +} + +// ExtraSignedFieldsFromTypeMapFn returns the unhandled entries from a +// tlv.TypeMap that the supplied predicate keeps in the signed range, so the +// caller can re-emit them and keep the message signature valid. Entries for +// which isUnsigned returns true are dropped. +func ExtraSignedFieldsFromTypeMapFn(m tlv.TypeMap, + isUnsigned UnsignedRangeFunc) ExtraSignedFields { + extraFields := make(ExtraSignedFields) for t, v := range m { - // If the value in the type map is nil, then it indicates that - // we know this type, and it was handled by one of the records - // we passed to the decode function vai the TLV stream. + // A nil value signals that this type was consumed by one of the + // typed records passed to the TLV stream decoder, so its bytes + // are already represented elsewhere and do not need to be + // tracked here. if v == nil { continue } - // No need to keep this field if it is unknown to us and is not - // in the sign range. - if InUnsignedRange(t) { + // Types the predicate places outside the signed range fall + // outside the signature's coverage, so they do not need to + // survive into re-encoding. + if isUnsigned(t) { continue } - // Otherwise, this is an un-handled type, so we keep track of - // it for signature validation and re-encoding later on. + // The remaining types are unhandled but within the signed + // range; preserve their raw bytes so the message can re-emit + // them verbatim and the signature stays valid. extraFields[uint64(t)] = v } diff --git a/lnwire/pure_tlv_test.go b/lnwire/pure_tlv_test.go index a81a89ecb6d..9148678d2a0 100644 --- a/lnwire/pure_tlv_test.go +++ b/lnwire/pure_tlv_test.go @@ -387,3 +387,89 @@ func (g *MsgV2) AllRecords() []tlv.Record { return ProduceRecordsSorted(recordProducers...) } + +// mockPureTLVMessage is a minimal PureTLVMessage backed by a fixed record +// slice, used to exercise the predicate-driven helpers. +type mockPureTLVMessage struct { + records []tlv.Record +} + +func (m *mockPureTLVMessage) AllRecords() []tlv.Record { + return m.records +} + +// TestSerialiseFieldsToSignFn verifies that the serialiser correctly filters +// records based on the provided predicate before encoding. +func TestSerialiseFieldsToSignFn(t *testing.T) { + t.Parallel() + + var ( + signedVal uint16 = 11 + unsignedVal uint16 = 22 + ) + + msg := &mockPureTLVMessage{ + records: []tlv.Record{ + tlv.MakePrimitiveRecord(5, &signedVal), + tlv.MakePrimitiveRecord(10, &unsignedVal), + }, + } + + // Predicate that defines type 10 as unsigned (excluded). + isUnsigned := func(typ tlv.Type) bool { + return typ == 10 + } + + encoded, err := SerialiseFieldsToSignFn(msg, isUnsigned) + require.NoError(t, err) + + // Only type 5 should be encoded (type 5, length 2, value 11). + require.Equal(t, []byte{0x05, 0x02, 0x00, 0x0b}, encoded) +} + +// TestExtraSignedFieldsFromTypeMapFn confirms the predicate-driven variant +// keeps and drops the right type ranges for callers whose signed range is not +// the BOLT 7 v2 default. It also locks in the round-trip identity with the +// convenience wrapper. +func TestExtraSignedFieldsFromTypeMapFn(t *testing.T) { + t.Parallel() + + // Bolt12 signature TLVs sit at 240-1000 and are excluded from the + // signed Merkle tree. Everything else is signed. + bolt12Unsigned := func(typ tlv.Type) bool { + return typ >= 240 && typ <= 1000 + } + + typeMap := tlv.TypeMap{ + // Handled by a typed record on the receiver. + tlv.Type(2): nil, + + // Unknown type in the bolt12 signed range — must survive. + tlv.Type(99): { + 0x01, + }, + + // Bolt12 signature TLV — must be dropped. + tlv.Type(240): { + 0x02, + }, + + // Bolt12 second-range type — signed for bolt12, signed for the + // BOLT 7 v2 default too. + tlv.Type(1_500_000_000): { + 0x03, + }, + } + + gotBolt12 := ExtraSignedFieldsFromTypeMapFn(typeMap, bolt12Unsigned) + require.Len(t, gotBolt12, 2) + require.Equal(t, []byte{0x01}, gotBolt12[99]) + require.Equal(t, []byte{0x03}, gotBolt12[1_500_000_000]) + + gotDefault := ExtraSignedFieldsFromTypeMap(typeMap) + // In the BOLT 7 v2 range, type 99 is signed but type 240 is unsigned. + require.Len(t, gotDefault, 2) + require.Equal(t, []byte{0x01}, gotDefault[99]) + require.Equal(t, []byte{0x03}, gotDefault[1_500_000_000]) + require.NotContains(t, gotDefault, uint64(240)) +} From a8b506b7198acd87c54a0fd985a7255ea4ce1fe5 Mon Sep 17 00:00:00 2001 From: bitromortac Date: Tue, 5 May 2026 14:05:22 +0200 Subject: [PATCH 02/11] lnwire: add SetOptFromMap and AddOpt This gives us easier optional tlv field handling, which we will use for the following message definitions. --- lnwire/custom_records.go | 26 +++++++++++++++++++++ lnwire/custom_records_test.go | 43 +++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/lnwire/custom_records.go b/lnwire/custom_records.go index de5ff4a2302..90f99d264e6 100644 --- a/lnwire/custom_records.go +++ b/lnwire/custom_records.go @@ -263,6 +263,32 @@ func DecodeRecordsP2P(r *bytes.Reader, return tlvStream.DecodeWithParsedTypesP2P(r) } +// AddOpt appends a record producer for the given optional record to producers +// when the optional is set, leaving producers unchanged otherwise. +func AddOpt[T tlv.TlvType, V any](producers *[]tlv.RecordProducer, + opt tlv.OptionalRecordT[T, V]) { + + opt.WhenSome( + func(r tlv.RecordT[T, V]) { + *producers = append(*producers, &r) + }, + ) +} + +// SetOptFromMap marks target as Some(record) when record's TLV type appeared +// on the wire (i.e., is a key in the decoded TypeMap). +// +// The caller must have passed record to the underlying Stream before decoding; +// otherwise record.Val will not have been populated, and wrapping it as Some +// would yield a zero-valued field. +func SetOptFromMap[T tlv.TlvType, V any](typeMap tlv.TypeMap, + target *tlv.OptionalRecordT[T, V], record tlv.RecordT[T, V]) { + + if _, ok := typeMap[record.TlvType()]; ok { + *target = tlv.SomeRecordT(record) + } +} + // AssertUniqueTypes asserts that the given records have unique types. func AssertUniqueTypes(r []tlv.Record) error { seen := make(fn.Set[tlv.Type], len(r)) diff --git a/lnwire/custom_records_test.go b/lnwire/custom_records_test.go index d4aad2e5462..d14586b8e17 100644 --- a/lnwire/custom_records_test.go +++ b/lnwire/custom_records_test.go @@ -249,3 +249,46 @@ func TestCustomRecordsMergedCopy(t *testing.T) { }) } } + +// TestAddOptAppendsOnlyWhenSet checks that AddOpt is a no-op for an empty +// optional and appends a producer when the optional is populated. +func TestAddOptAppendsOnlyWhenSet(t *testing.T) { + t.Parallel() + + var producers []tlv.RecordProducer + + emptyOpt := tlv.OptionalRecordT[tlv.TlvType1, uint16]{} + AddOpt(&producers, emptyOpt) + require.Empty(t, producers) + + setOpt := tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType1, uint16](42), + ) + AddOpt(&producers, setOpt) + require.Len(t, producers, 1) + + rec := producers[0].Record() + require.Equal(t, tlv.Type(1), rec.Type()) +} + +// TestSetOptFromMapUsesTypeMapPresence verifies that SetOptFromMap populates +// only when the TLV type is present in the TypeMap. +func TestSetOptFromMapUsesTypeMapPresence(t *testing.T) { + t.Parallel() + + present := tlv.TypeMap{tlv.Type(1): nil} + missing := tlv.TypeMap{} + + var target tlv.OptionalRecordT[tlv.TlvType1, uint16] + SetOptFromMap( + missing, &target, + tlv.NewPrimitiveRecord[tlv.TlvType1, uint16](7), + ) + require.True(t, target.IsNone()) + + SetOptFromMap( + present, &target, + tlv.NewPrimitiveRecord[tlv.TlvType1, uint16](7), + ) + require.True(t, target.IsSome()) +} From ad1fbd1c99386fc135a65c2895846f4411da050c Mon Sep 17 00:00:00 2001 From: bitromortac Date: Mon, 11 May 2026 11:27:04 +0200 Subject: [PATCH 03/11] lnwire: add bounded introNode BlindedPath codec Introduce the canonical lnwire.BlindedPath / BlindedPaths codec with a sealed IntroductionNode sum-type covering both the BOLT 4 pubkey and sciddir variants. The codec gates every variable-length subfield against an io.LimitedReader. It fails closed on the encoder side so invalid input never hits the wire. This commit is a pure addition: no existing caller changes. Subsequent commits migrate OnionMessagePayload and the bolt12 message structs to consume the new codec. --- lnwire/blinded_path.go | 333 +++++++++++++++++++++++++++++ lnwire/blinded_path_test.go | 414 ++++++++++++++++++++++++++++++++++++ lnwire/bounds.go | 35 +++ lnwire/intro_node.go | 150 +++++++++++++ 4 files changed, 932 insertions(+) create mode 100644 lnwire/blinded_path.go create mode 100644 lnwire/blinded_path_test.go create mode 100644 lnwire/bounds.go create mode 100644 lnwire/intro_node.go diff --git a/lnwire/blinded_path.go b/lnwire/blinded_path.go new file mode 100644 index 00000000000..14e247af479 --- /dev/null +++ b/lnwire/blinded_path.go @@ -0,0 +1,333 @@ +package lnwire + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + // ErrInvalidIntroNode is returned when a blinded path's introduction + // node discriminator is not one of the spec-defined values. + ErrInvalidIntroNode = errors.New("invalid blinded-path introduction " + + "node discriminator") + + // ErrEmptyBlindedPath is returned when a blinded path has zero hops. + ErrEmptyBlindedPath = errors.New("blinded path with zero hops") +) + +// BlindedPath holds the introduction node, blinding point, and encrypted hops +// of a single blinded path. +type BlindedPath struct { + // IntroductionNode is the variant-defined introduction node for this + // blinded path. + IntroductionNode IntroductionNode + + // BlindingPoint is the blinding point for this path, used to derive the + // blinded node IDs and encrypt the hop payloads. + BlindingPoint *btcec.PublicKey + + // Hops is the ordered list of blinded hops in this path. + Hops []BlindedHop +} + +// BlindedPaths holds one or more blinded paths. +type BlindedPaths struct { + Paths []BlindedPath +} + +// BlindedHop represents a single hop in a blinded path. +type BlindedHop struct { + // BlindedNodeID is the blinded public key for this hop. + BlindedNodeID *btcec.PublicKey + + // EncryptedData is the encrypted payload for this hop. + EncryptedData []byte +} + +var ( + _ tlv.RecordProducer = (*BlindedPath)(nil) + _ tlv.RecordProducer = (*BlindedPaths)(nil) +) + +// Record returns a TLV record for a single BlindedPath at the BOLT 4 reply_path +// TLV type. Used directly by OnionMessagePayload's reply_path encoding. +func (p *BlindedPath) Record() tlv.Record { + return tlv.MakeDynamicRecord( + replyPathType, p, + func() uint64 { + return blindedPathSize(p) + }, + encodeBlindedPath, + decodeBlindedPath, + ) +} + +// blindedPathSize returns the on-wire size of a single BlindedPath. +func blindedPathSize(p *BlindedPath) uint64 { + var introLen uint64 + if p.IntroductionNode != nil { + introLen = p.IntroductionNode.encodedLen() + } + + // introduction_node (variant-defined) + blinding_point (33) + + // num_hops (1). + size := introLen + pubKeyLen + 1 + for _, h := range p.Hops { + // blinded_node_id (33) + enclen (2) + enc_data. + size += pubKeyLen + 2 + uint64(len(h.EncryptedData)) + } + + return size +} + +// encodeBlindedPath writes a single blinded path. No bytes are written if the +// path fails validation. +func encodeBlindedPath(w io.Writer, val any, buf *[8]byte) error { + p, ok := val.(*BlindedPath) + if !ok { + return fmt.Errorf("expected *BlindedPath, got %T", val) + } + + return writeBlindedPath(w, p, buf) +} + +// writeBlindedPath validates the path and writes a single blinded path to w. +func writeBlindedPath(w io.Writer, p *BlindedPath, buf *[8]byte) error { + if p.IntroductionNode == nil { + return fmt.Errorf("nil intro node") + } + + if err := p.IntroductionNode.validate(); err != nil { + return err + } + + if p.BlindingPoint == nil { + return fmt.Errorf("nil blinding point") + } + + if !p.BlindingPoint.IsOnCurve() { + return fmt.Errorf("blinding point not on curve") + } + + if len(p.Hops) == 0 { + return ErrEmptyBlindedPath + } + if len(p.Hops) > maxBlindedPathHops { + return fmt.Errorf("%d hops exceeds limit %d", len(p.Hops), + maxBlindedPathHops) + } + + if err := p.IntroductionNode.encode(w); err != nil { + return err + } + blindingBytes := p.BlindingPoint.SerializeCompressed() + if _, err := w.Write(blindingBytes); err != nil { + return err + } + + buf[0] = uint8(len(p.Hops)) + if _, err := w.Write(buf[:1]); err != nil { + return err + } + + for hIdx := range p.Hops { + if err := writeBlindedHop(w, &p.Hops[hIdx], buf); err != nil { + return fmt.Errorf("hop %d: %w", hIdx, err) + } + } + + return nil +} + +// decodeBlindedPath reads a single blinded path framed at the TLV-value level. +func decodeBlindedPath(r io.Reader, val any, buf *[8]byte, l uint64) error { + p, ok := val.(*BlindedPath) + if !ok { + return fmt.Errorf("expected *BlindedPath, got %T", val) + } + + lr := &io.LimitedReader{R: r, N: int64(l)} + + if err := readBlindedPath(lr, p, buf); err != nil { + return err + } + + if lr.N != 0 { + return fmt.Errorf("trailing %d bytes after blinded path", lr.N) + } + + return nil +} + +// readBlindedPath decodes a single blinded path from lr. +func readBlindedPath(lr *io.LimitedReader, p *BlindedPath, + buf *[8]byte) error { + + intro, err := decodeIntroductionNode(lr, buf) + if err != nil { + return err + } + p.IntroductionNode = intro + + var blindingBytes [pubKeyLen]byte + if _, err := io.ReadFull(lr, blindingBytes[:]); err != nil { + return fmt.Errorf("read blinding point: %w", err) + } + blinding, err := btcec.ParsePubKey(blindingBytes[:]) + if err != nil { + return fmt.Errorf("blinding point: %w", err) + } + p.BlindingPoint = blinding + + if _, err := io.ReadFull(lr, buf[:1]); err != nil { + return fmt.Errorf("read num_hops: %w", err) + } + numHops := int(buf[0]) + if numHops == 0 { + return ErrEmptyBlindedPath + } + + if int64(numHops)*minBlindedHopBytes > lr.N { + return fmt.Errorf("num_hops %d exceeds remaining %d bytes", + numHops, lr.N) + } + + p.Hops = make([]BlindedHop, numHops) + for i := range p.Hops { + if err := readBlindedHop(lr, &p.Hops[i], buf); err != nil { + return err + } + } + + return nil +} + +// Record returns a TLV record for BlindedPaths. +func (bp *BlindedPaths) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 0, bp, + func() uint64 { + return blindedPathsSize(bp) + }, + encodeBlindedPaths, + decodeBlindedPaths, + ) +} + +// blindedPathsSize returns the on-wire size of multiple BlindedPaths. +func blindedPathsSize(bp *BlindedPaths) uint64 { + var size uint64 + for i := range bp.Paths { + size += blindedPathSize(&bp.Paths[i]) + } + + return size +} + +// encodeBlindedPaths writes the multi-path TLV value as concatenated paths. +// Fails closed under the same conditions as encodeBlindedPath. +func encodeBlindedPaths(w io.Writer, val any, buf *[8]byte) error { + bp, ok := val.(*BlindedPaths) + if !ok { + return fmt.Errorf("expected *BlindedPaths, got %T", val) + } + + for pIdx := range bp.Paths { + err := writeBlindedPath(w, &bp.Paths[pIdx], buf) + if err != nil { + return fmt.Errorf("blinded path %d: %w", pIdx, err) + } + } + + return nil +} + +// decodeBlindedPaths reads concatenated blinded paths. The LimitedReader gates +// each variable-length subfield against the bytes still on the wire, so an +// oversize hop count cannot force a large allocation before io.ReadFull +// notices the bytes are absent. +func decodeBlindedPaths(r io.Reader, val any, buf *[8]byte, l uint64) error { + bp, ok := val.(*BlindedPaths) + if !ok { + return fmt.Errorf("expected *BlindedPaths, got %T", val) + } + + lr := &io.LimitedReader{R: r, N: int64(l)} + + for lr.N > 0 { + var p BlindedPath + if err := readBlindedPath(lr, &p, buf); err != nil { + return err + } + bp.Paths = append(bp.Paths, p) + } + + return nil +} + +// writeBlindedHop emits BlindedNodeID + enclen + encrypted data. The size cap +// is checked first so no bytes hit the writer on rejection. +func writeBlindedHop(w io.Writer, h *BlindedHop, buf *[8]byte) error { + if h.BlindedNodeID == nil { + return fmt.Errorf("nil blinded node id") + } + + if !h.BlindedNodeID.IsOnCurve() { + return fmt.Errorf("blinded node id not on curve") + } + + if len(h.EncryptedData) > maxEncryptedDataLen { + return fmt.Errorf("encrypted data %d exceeds limit %d", + len(h.EncryptedData), maxEncryptedDataLen) + } + + nodeIDBytes := h.BlindedNodeID.SerializeCompressed() + if _, err := w.Write(nodeIDBytes); err != nil { + return err + } + + binary.BigEndian.PutUint16(buf[:2], uint16(len(h.EncryptedData))) + if _, err := w.Write(buf[:2]); err != nil { + return err + } + if _, err := w.Write(h.EncryptedData); err != nil { + return err + } + + return nil +} + +// readBlindedHop decodes a single blinded hop. The enclen guard against lr.N +// bounds the EncryptedData allocation. +func readBlindedHop(lr *io.LimitedReader, h *BlindedHop, buf *[8]byte) error { + var nodeBytes [pubKeyLen]byte + if _, err := io.ReadFull(lr, nodeBytes[:]); err != nil { + return fmt.Errorf("read blinded node: %w", err) + } + node, err := btcec.ParsePubKey(nodeBytes[:]) + if err != nil { + return fmt.Errorf("blinded node id: %w", err) + } + h.BlindedNodeID = node + + if _, err := io.ReadFull(lr, buf[:2]); err != nil { + return fmt.Errorf("read enclen: %w", err) + } + encLen := binary.BigEndian.Uint16(buf[:2]) + if int64(encLen) > lr.N { + return fmt.Errorf("enclen %d exceeds remaining %d", encLen, + lr.N) + } + + h.EncryptedData = make([]byte, encLen) + if _, err := io.ReadFull(lr, h.EncryptedData); err != nil { + return fmt.Errorf("read encrypted data: %w", err) + } + + return nil +} diff --git a/lnwire/blinded_path_test.go b/lnwire/blinded_path_test.go new file mode 100644 index 00000000000..207e08a45a0 --- /dev/null +++ b/lnwire/blinded_path_test.go @@ -0,0 +1,414 @@ +package lnwire + +import ( + "bytes" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/stretchr/testify/require" +) + +// validPubkeyIntro returns an on-curve PubkeyIntro plus the matching +// *btcec.PublicKey for assertions. +func validPubkeyIntro(t *testing.T) (PubkeyIntro, *btcec.PublicKey) { + t.Helper() + + priv, err := btcec.NewPrivateKey() + require.NoError(t, err) + pub := priv.PubKey() + + return PubkeyIntro{Pubkey: pub}, pub +} + +// validBlindingPoint returns an on-curve pubkey suitable for use as a +// BlindingPoint or BlindedNodeID in tests. +func validBlindingPoint(t *testing.T) *btcec.PublicKey { + t.Helper() + + priv, err := btcec.NewPrivateKey() + require.NoError(t, err) + + return priv.PubKey() +} + +// oversizeEncDataPaths returns a BlindedPaths with a single hop whose +// EncryptedData is one byte over the wire-format limit, used by the +// encode-rejects test. +func oversizeEncDataPaths(t *testing.T, intro IntroductionNode) *BlindedPaths { + t.Helper() + + return &BlindedPaths{ + Paths: []BlindedPath{{ + IntroductionNode: intro, + BlindingPoint: validBlindingPoint(t), + Hops: []BlindedHop{{ + BlindedNodeID: validBlindingPoint(t), + EncryptedData: make( + []byte, maxEncryptedDataLen+1, + ), + }}, + }}, + } +} + +// TestBlindedPathRoundTrip pins encode→decode parity across both +// IntroductionNode variants and across single- and multi-path framings, so +// concrete variant types survive the round-trip with byte-identical output. +func TestBlindedPathRoundTrip(t *testing.T) { + t.Parallel() + + pubkeyIntro, _ := validPubkeyIntro(t) + sciddirIntro := SciddirIntro{ + Direction: 0x01, + SCID: [8]byte{ + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, + }, + } + + hop := func(payload byte) BlindedHop { + return BlindedHop{ + BlindedNodeID: validBlindingPoint(t), + EncryptedData: []byte{payload, payload ^ 0xff}, + } + } + + pubkeyPath := BlindedPath{ + IntroductionNode: pubkeyIntro, + BlindingPoint: validBlindingPoint(t), + Hops: []BlindedHop{ + hop(0xde), + hop(0xad), + }, + } + sciddirPath := BlindedPath{ + IntroductionNode: sciddirIntro, + BlindingPoint: validBlindingPoint(t), + Hops: []BlindedHop{hop(0xbe)}, + } + + tests := []struct { + name string + paths []BlindedPath + }{ + { + name: "single pubkey path", + paths: []BlindedPath{pubkeyPath}, + }, + { + name: "single sciddir path", + paths: []BlindedPath{sciddirPath}, + }, + { + name: "mixed multi-path", + paths: []BlindedPath{pubkeyPath, sciddirPath}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + bp := &BlindedPaths{Paths: tc.paths} + + var buf bytes.Buffer + require.NoError(t, encodeBlindedPaths( + &buf, bp, new([8]byte), + )) + + var decoded BlindedPaths + err := decodeBlindedPaths( + bytes.NewReader(buf.Bytes()), &decoded, + new([8]byte), uint64(buf.Len()), + ) + require.NoError(t, err) + require.Equal(t, bp.Paths, decoded.Paths) + + // Single-path framing must round-trip too: the + // reply_path TLV carries one BlindedPath, not a list. + if len(tc.paths) == 1 { + var single bytes.Buffer + require.NoError(t, encodeBlindedPath( + &single, &tc.paths[0], new([8]byte), + )) + + var decodedSingle BlindedPath + err := decodeBlindedPath( + bytes.NewReader(single.Bytes()), + &decodedSingle, new([8]byte), + uint64(single.Len()), + ) + require.NoError(t, err) + require.Equal( + t, tc.paths[0], decodedSingle, + ) + } + }) + } +} + +// TestDecodeBlindedPathsRejects covers every malformed-input branch the +// decoder must refuse: bad discriminators, allocation bombs, and short reads. +// The catch-all is that the decoder never allocates more memory than the +// remaining wire bytes can justify. +func TestDecodeBlindedPathsRejects(t *testing.T) { + t.Parallel() + + // validKey is a 33-byte compressed SEC1 pubkey that the on-curve + // decoder accepts; reused as both intro pubkey and blinding point so + // the tests can exercise post-pubkey decode branches. + validKey := validBlindingPoint(t).SerializeCompressed() + + // hopAllocOverflow declares num_hops=255 with no hop payload. Without + // the remaining-bytes guard the decoder would make([]BlindedHop, 255) + // before io.ReadFull notices the bytes are absent. + hopAllocOverflow := func() []byte { + out := make([]byte, 0, 67) + out = append(out, validKey...) + out = append(out, validKey...) + out = append(out, 0xff) + + return out + } + + // enclenOverflow declares enclen=65535 on a hop with no payload. The + // guard against lr.N must reject before make([]byte, 65535). + enclenOverflow := func() []byte { + out := make([]byte, 0, 70) + out = append(out, validKey...) + out = append(out, validKey...) + out = append(out, 0x01) + out = append(out, validKey...) + out = append(out, 0xff, 0xff) + + return out + } + + // shortIntroPubkey truncates after the discriminator + 5 of 33 bytes + // of intro pubkey, exercising io.ReadFull's short-read error. + shortIntroPubkey := func() []byte { + return append([]byte{0x02}, bytes.Repeat([]byte{0x00}, 5)...) + } + + // shortBlindingPoint truncates after a full intro pubkey plus 5 of the + // 33 blinding-point bytes, exercising io.ReadFull's short-read path + // past the discriminator. + shortBlindingPoint := func() []byte { + out := make([]byte, 0, pubKeyLen+5) + out = append(out, validKey...) + out = append(out, bytes.Repeat([]byte{0x00}, 5)...) + + return out + } + + tests := []struct { + name string + data []byte + wantErr error + wantMsg []string + }{ + { + name: "invalid discriminator 0x04", + data: []byte{0x04}, + wantErr: ErrInvalidIntroNode, + }, + { + name: "invalid discriminator 0x05", + data: []byte{0x05}, + wantErr: ErrInvalidIntroNode, + }, + { + name: "invalid discriminator 0xff", + data: []byte{0xff}, + wantErr: ErrInvalidIntroNode, + }, + { + name: "hop alloc overflow", + data: hopAllocOverflow(), + wantMsg: []string{"num_hops", "exceeds remaining"}, + }, + { + name: "enclen alloc overflow", + data: enclenOverflow(), + wantMsg: []string{"enclen", "exceeds remaining"}, + }, + { + name: "short intro pubkey", + data: shortIntroPubkey(), + wantMsg: []string{"read intro pubkey"}, + }, + { + name: "short blinding point", + data: shortBlindingPoint(), + wantMsg: []string{"read blinding point"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var bp BlindedPaths + err := decodeBlindedPaths( + bytes.NewReader(tc.data), &bp, new([8]byte), + uint64(len(tc.data)), + ) + require.Error(t, err) + + if tc.wantErr != nil { + require.ErrorIs(t, err, tc.wantErr) + } + for _, msg := range tc.wantMsg { + require.Contains(t, err.Error(), msg) + } + }) + } +} + +// TestEncodeBlindedPathsRejects pins the encoder's fail-closed guards. Any +// case here must not emit bytes — invalid input cannot be retracted from the +// wire once flushed. +func TestEncodeBlindedPathsRejects(t *testing.T) { + t.Parallel() + + validIntro, _ := validPubkeyIntro(t) + validHop := BlindedHop{BlindedNodeID: validBlindingPoint(t)} + + tests := []struct { + name string + paths *BlindedPaths + wantErr error + wantMsg []string + wantNoWrite bool + }{ + { + name: "nil intro", + paths: &BlindedPaths{ + Paths: []BlindedPath{{ + BlindingPoint: validBlindingPoint(t), + Hops: []BlindedHop{validHop}, + }}, + }, + wantMsg: []string{"nil intro node"}, + wantNoWrite: true, + }, + { + name: "nil pubkey in PubkeyIntro", + paths: &BlindedPaths{ + Paths: []BlindedPath{{ + IntroductionNode: PubkeyIntro{}, + BlindingPoint: validBlindingPoint(t), + Hops: []BlindedHop{ + validHop, + }, + }}, + }, + wantErr: ErrInvalidIntroNode, + wantNoWrite: true, + }, + { + name: "invalid sciddir direction 0x02", + paths: &BlindedPaths{ + Paths: []BlindedPath{{ + IntroductionNode: SciddirIntro{ + Direction: 0x02, + }, + BlindingPoint: validBlindingPoint(t), + Hops: []BlindedHop{validHop}, + }}, + }, + wantErr: ErrInvalidIntroNode, + wantNoWrite: true, + }, + { + name: "invalid sciddir direction 0xff", + paths: &BlindedPaths{ + Paths: []BlindedPath{{ + IntroductionNode: SciddirIntro{ + Direction: 0xff, + }, + BlindingPoint: validBlindingPoint(t), + Hops: []BlindedHop{validHop}, + }}, + }, + wantErr: ErrInvalidIntroNode, + wantNoWrite: true, + }, + { + name: "nil blinding point", + paths: &BlindedPaths{ + Paths: []BlindedPath{{ + IntroductionNode: validIntro, + Hops: []BlindedHop{ + validHop, + }, + }}, + }, + wantMsg: []string{"nil blinding point"}, + wantNoWrite: true, + }, + { + name: "zero hops", + paths: &BlindedPaths{ + Paths: []BlindedPath{{ + IntroductionNode: validIntro, + BlindingPoint: validBlindingPoint(t), + Hops: nil, + }}, + }, + wantErr: ErrEmptyBlindedPath, + wantNoWrite: true, + }, + { + name: "hop overflow", + paths: &BlindedPaths{ + Paths: []BlindedPath{{ + IntroductionNode: validIntro, + BlindingPoint: validBlindingPoint(t), + Hops: func() []BlindedHop { + hops := make([]BlindedHop, + maxBlindedPathHops+1) + pub := validBlindingPoint(t) + for i := range hops { + // Write to hop. + h := &hops[i] + h.BlindedNodeID = pub + } + + return hops + }(), + }}, + }, + wantMsg: []string{"exceeds limit"}, + wantNoWrite: true, + }, + { + name: "oversize encrypted data", + paths: oversizeEncDataPaths(t, validIntro), + wantMsg: []string{"exceeds limit"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + err := encodeBlindedPaths( + &buf, tc.paths, new([8]byte), + ) + require.Error(t, err) + + if tc.wantErr != nil { + require.ErrorIs(t, err, tc.wantErr) + } + for _, msg := range tc.wantMsg { + require.Contains(t, err.Error(), msg) + } + if tc.wantNoWrite { + require.Equal(t, 0, buf.Len(), + "encoder wrote bytes on fail-closed "+ + "path") + } + }) + } +} diff --git a/lnwire/bounds.go b/lnwire/bounds.go new file mode 100644 index 00000000000..8cb79bb4f98 --- /dev/null +++ b/lnwire/bounds.go @@ -0,0 +1,35 @@ +package lnwire + +import ( + "math" + + "github.com/btcsuite/btcd/btcec/v2" +) + +// BOLT 4 blinded-path field bounds. Each constant matches the format ceiling +// imposed by the spec encoding (uint8 num_hops, uint16 enclen). +const ( + // pubKeyLen aliases the upstream compressed-pubkey length for shorter + // usage in this package. + pubKeyLen = btcec.PubKeyBytesLenCompressed + + // sciddirLen is the on-wire length of a sciddir introduction node + // (1-byte direction + 8-byte SCID). + sciddirLen = 9 + + // scidLen is the byte length of a short channel ID. + scidLen = 8 + + // maxBlindedPathHops bounds the number of hops a single blinded path + // may declare. The spec encodes num_hops as a uint8, so 255 is the + // format's absolute ceiling. + maxBlindedPathHops = math.MaxUint8 + + // maxEncryptedDataLen bounds the encrypted-data field in a single + // blinded hop. The spec encodes the length as a uint16. + maxEncryptedDataLen = math.MaxUint16 + + // minBlindedHopBytes is the on-wire footprint of the smallest possible + // blinded hop: BlindedNodeID(33) + enclen(2) + 0 enc_data. + minBlindedHopBytes = pubKeyLen + 2 +) diff --git a/lnwire/intro_node.go b/lnwire/intro_node.go new file mode 100644 index 00000000000..af79fec7ca7 --- /dev/null +++ b/lnwire/intro_node.go @@ -0,0 +1,150 @@ +package lnwire + +import ( + "bytes" + "fmt" + "io" + + "github.com/btcsuite/btcd/btcec/v2" +) + +// IntroductionNode is the sealed sum-type for a blinded path's introduction +// node. {0x02, 0x03} → PubkeyIntro; {0x00, 0x01} → SciddirIntro. The unexported +// method seals the variant set so foreign packages cannot satisfy the interface +// with an unrecognised wire form. +type IntroductionNode interface { + isIntroductionNode() + + encodedLen() uint64 + + encode(w io.Writer) error + + // validate checks that the discriminator byte is valid for the variant. + validate() error + + // Bytes returns the wire-format encoding of the introduction node for + // callers that need it outside an io.Writer (RPC surfaces). + Bytes() []byte +} + +// PubkeyIntro is the 33-byte compressed-pubkey variant. The SEC1 parity byte +// (0x02 or 0x03) doubles as the wire discriminator. +type PubkeyIntro struct { + Pubkey *btcec.PublicKey +} + +// SciddirIntro is the 9-byte sciddir variant. Direction is the wire +// discriminator; SCID is the 8-byte short channel ID. +type SciddirIntro struct { + Direction byte + SCID [scidLen]byte +} + +var ( + _ IntroductionNode = PubkeyIntro{} + _ IntroductionNode = SciddirIntro{} +) + +// decodeIntroductionNode reads the discriminator byte and dispatches to the +// matching variant. +func decodeIntroductionNode(r io.Reader, + buf *[8]byte) (IntroductionNode, error) { + + if _, err := io.ReadFull(r, buf[:1]); err != nil { + return nil, fmt.Errorf("read intro node type: %w", err) + } + + disc := buf[0] + switch disc { + case 0x00, 0x01: + s := SciddirIntro{Direction: disc} + if _, err := io.ReadFull(r, s.SCID[:]); err != nil { + return nil, fmt.Errorf("read sciddir: %w", err) + } + + return s, nil + + case 0x02, 0x03: + var b [pubKeyLen]byte + b[0] = disc + if _, err := io.ReadFull(r, b[1:]); err != nil { + return nil, fmt.Errorf("read intro pubkey: %w", err) + } + pub, err := btcec.ParsePubKey(b[:]) + if err != nil { + return nil, fmt.Errorf("%w: %w", + ErrInvalidIntroNode, err) + } + + return PubkeyIntro{Pubkey: pub}, nil + + default: + return nil, fmt.Errorf("%w: 0x%02x", ErrInvalidIntroNode, disc) + } +} + +func (PubkeyIntro) isIntroductionNode() {} + +func (p PubkeyIntro) encodedLen() uint64 { return pubKeyLen } + +func (p PubkeyIntro) encode(w io.Writer) error { + if p.Pubkey == nil { + return fmt.Errorf("nil intro pubkey") + } + _, err := w.Write(p.Pubkey.SerializeCompressed()) + + return err +} + +func (p PubkeyIntro) validate() error { + if p.Pubkey == nil { + return fmt.Errorf("%w: nil pubkey", ErrInvalidIntroNode) + } + + if !p.Pubkey.IsOnCurve() { + return fmt.Errorf("%w: pubkey not on curve", + ErrInvalidIntroNode) + } + + return nil +} + +// Bytes returns the wire-format encoding of the pubkey variant. +func (p PubkeyIntro) Bytes() []byte { + var buf bytes.Buffer + buf.Grow(pubKeyLen) + _ = p.encode(&buf) + + return buf.Bytes() +} + +func (SciddirIntro) isIntroductionNode() {} + +func (s SciddirIntro) encodedLen() uint64 { return sciddirLen } + +func (s SciddirIntro) encode(w io.Writer) error { + if _, err := w.Write([]byte{s.Direction}); err != nil { + return err + } + _, err := w.Write(s.SCID[:]) + + return err +} + +func (s SciddirIntro) validate() error { + switch s.Direction { + case 0x00, 0x01: + return nil + } + + return fmt.Errorf("%w: 0x%02x", ErrInvalidIntroNode, s.Direction) +} + +// Bytes returns the wire-format encoding of the sciddir variant. +func (s SciddirIntro) Bytes() []byte { + var buf bytes.Buffer + buf.Grow(sciddirLen) + _ = s.encode(&buf) + + return buf.Bytes() +} From 7a754767ef0b15a659128bdd3fd8440e79c44e18 Mon Sep 17 00:00:00 2001 From: bitromortac Date: Fri, 8 May 2026 13:09:09 +0200 Subject: [PATCH 04/11] multi: migrate OnionMessagePayload to lnwire.BlindedPath Switch OnionMessagePayload.ReplyPath from *sphinx.BlindedPath to *lnwire.BlindedPath. The reply-path TLV is now produced and consumed by (*lnwire.BlindedPath).Record(), which honours the BOLT 4 sciddir_or_pubkey introduction-node form. The legacy decoder gated on a 67-byte minimum length and silently rejected reply paths whose introduction node used the 9-byte sciddir variant. The legacy replyPathRecord / replyPathSize / encodeReplyPath / decodeReplyPath / blindedHopSize / encodeBlindedHop / decodeBlindedHop helpers and the unused ErrNoHops sentinel are deleted. Consumers update mechanically: routing/route's OnionMessageBlindedPathToSphinxPath replyPath parameter, the onionmessage.OnionMessageUpdate field, the rpcserver onion-message subscription bridge, and the lnwire test utilities now use the lnwire type directly. The new TestOnionMessagePayloadRoundTrip "sciddir intro reply path" subtest pins the BOLT 4 spec fix. --- lnwire/onion_msg_payload.go | 183 ++----------------------------- lnwire/onion_msg_payload_test.go | 125 ++++++++++----------- lnwire/test_utils.go | 49 ++++++--- onionmessage/onion_endpoint.go | 4 +- routing/route/blindedroute.go | 2 +- rpcserver.go | 18 +-- 6 files changed, 116 insertions(+), 265 deletions(-) diff --git a/lnwire/onion_msg_payload.go b/lnwire/onion_msg_payload.go index 64c5abadec3..f91c650d1fe 100644 --- a/lnwire/onion_msg_payload.go +++ b/lnwire/onion_msg_payload.go @@ -7,7 +7,6 @@ import ( "io" "sort" - sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/tlv" ) @@ -36,22 +35,15 @@ const ( InvoiceErrorNamespaceType tlv.Type = 68 ) -var ( - // ErrNotFinalPayload is returned when a final hop payload is not - // within the correct range. - ErrNotFinalPayload = errors.New("final hop payloads type should be " + - ">= 64") - - // ErrNoHops is returned when we handle a reply path that does not - // have any hops (this makes no sense). - ErrNoHops = errors.New("reply path requires hops") -) +// ErrNotFinalPayload is returned when a final hop payload is not within the +// correct range. +var ErrNotFinalPayload = errors.New("final hop payloads type should be >= 64") // OnionMessagePayload contains the contents of an onion message payload. type OnionMessagePayload struct { // ReplyPath contains a blinded path that can be used to respond to an // onion message. - ReplyPath *sphinx.BlindedPath + ReplyPath *BlindedPath // EncryptedData contains encrypted data for the recipient. EncryptedData []byte @@ -73,7 +65,7 @@ func (o *OnionMessagePayload) Encode() ([]byte, error) { var records []tlv.Record if o.ReplyPath != nil { - records = append(records, replyPathRecord(o.ReplyPath)) + records = append(records, o.ReplyPath.Record()) } if len(o.EncryptedData) != 0 { @@ -131,11 +123,13 @@ func (o *OnionMessagePayload) Decode(r io.Reader) (map[tlv.Type][]byte, error) { TLVType: InvoiceRequestNamespaceType, } ) - // Create a non-nil entry so that we can directly decode into it. - o.ReplyPath = &sphinx.BlindedPath{} + + // replyPath is used for decoding, we will later check if it was + // actually present and assign it to the message struct. + var replyPath BlindedPath records := []tlv.Record{ - replyPathRecord(o.ReplyPath), + replyPath.Record(), tlv.MakePrimitiveRecord( encryptedDataTLVType, &o.EncryptedData, ), @@ -171,9 +165,8 @@ func (o *OnionMessagePayload) Decode(r io.Reader) (map[tlv.Type][]byte, error) { return tlvMap, fmt.Errorf("decode stream: %w", err) } - // If our reply path wasn't populated, replace it with a nil entry. - if _, ok := tlvMap[replyPathType]; !ok { - o.ReplyPath = nil + if _, ok := tlvMap[replyPathType]; ok { + o.ReplyPath = &replyPath } // Once we're decoded our message, we want to also include any tlvs @@ -258,155 +251,3 @@ func (f *FinalHopTLV) Validate() error { return nil } - -// replyPathRecord produces a tlv record for a reply path. -func replyPathRecord(r *sphinx.BlindedPath) tlv.Record { - return tlv.MakeDynamicRecord( - replyPathType, r, replyPathSize(r), encodeReplyPath, - decodeReplyPath, - ) -} - -// replyPathSize returns the encoded size of a reply path. -func replyPathSize(r *sphinx.BlindedPath) func() uint64 { - return func() uint64 { - // First node pubkey 33 + blinding point pubkey 33 + 1 byte for - // uint8 for our hop count. - size := uint64(33 + 33 + 1) - - // Add each hop's size to our total. - for _, hop := range r.BlindedHops { - size += blindedHopSize(hop) - } - - return size - } -} - -// encodeReplyPath encodes a reply path tlv. -func encodeReplyPath(w io.Writer, val interface{}, buf *[8]byte) error { - if p, ok := val.(*sphinx.BlindedPath); ok { - err := tlv.EPubKey(w, &p.IntroductionPoint, buf) - if err != nil { - return fmt.Errorf("encode first node id: %w", err) - } - - if err := tlv.EPubKey(w, &p.BlindingPoint, buf); err != nil { - return fmt.Errorf("encode blinding point: %w", err) - } - - hopCount := uint8(len(p.BlindedHops)) - if hopCount == 0 { - return ErrNoHops - } - - if err := tlv.EUint8(w, &hopCount, buf); err != nil { - return fmt.Errorf("encode hop count: %w", err) - } - - for i, hop := range p.BlindedHops { - if err := encodeBlindedHop(w, hop, buf); err != nil { - return fmt.Errorf("hop %v: %w", i, err) - } - } - - return nil - } - - return tlv.NewTypeForEncodingErr(val, "*sphinx.BlindedPath") -} - -// decodeReplyPath decodes a reply path tlv. -func decodeReplyPath(r io.Reader, val interface{}, buf *[8]byte, - l uint64) error { - - // If we have the correct type, and the length exceeds the fixed header - // size (first node pubkey (33) + blinding point (33) + hop count (1) = - // 67 bytes) to accommodate at least one hop, decode the reply path. - if p, ok := val.(*sphinx.BlindedPath); ok && l > 67 { - err := tlv.DPubKey(r, &p.IntroductionPoint, buf, 33) - if err != nil { - return fmt.Errorf("decode first id: %w", err) - } - - err = tlv.DPubKey(r, &p.BlindingPoint, buf, 33) - if err != nil { - return fmt.Errorf("decode blinding point: %w", err) - } - - var hopCount uint8 - if err := tlv.DUint8(r, &hopCount, buf, 1); err != nil { - return fmt.Errorf("decode hop count: %w", err) - } - - if hopCount == 0 { - return ErrNoHops - } - - for i := 0; i < int(hopCount); i++ { - hop := &sphinx.BlindedHopInfo{} - if err := decodeBlindedHop(r, hop, buf); err != nil { - return fmt.Errorf("decode hop: %w", err) - } - - p.BlindedHops = append(p.BlindedHops, hop) - } - - return nil - } - - return tlv.NewTypeForDecodingErr(val, "*sphinx.BlindedPath", l, l) -} - -// blindedHopSize returns the encoded size of a blinded hop. -func blindedHopSize(b *sphinx.BlindedHopInfo) uint64 { - // 33 byte pubkey + 2 bytes uint16 length + var bytes. - return uint64(33 + 2 + len(b.CipherText)) -} - -// encodeBlindedHop encodes a blinded hop tlv. -func encodeBlindedHop(w io.Writer, val interface{}, buf *[8]byte) error { - if b, ok := val.(*sphinx.BlindedHopInfo); ok { - if err := tlv.EPubKey(w, &b.BlindedNodePub, buf); err != nil { - return fmt.Errorf("encode blinded id: %w", err) - } - - dataLen := uint16(len(b.CipherText)) - if err := tlv.EUint16(w, &dataLen, buf); err != nil { - return fmt.Errorf("data len: %w", err) - } - - if err := tlv.EVarBytes(w, &b.CipherText, buf); err != nil { - return fmt.Errorf("encode encrypted data: %w", err) - } - - return nil - } - - return tlv.NewTypeForEncodingErr(val, "*sphinx.BlindedHopInfo") -} - -// decodeBlindedHop decodes a blinded hop tlv. -func decodeBlindedHop(r io.Reader, val interface{}, buf *[8]byte) error { - if b, ok := val.(*sphinx.BlindedHopInfo); ok { - err := tlv.DPubKey(r, &b.BlindedNodePub, buf, 33) - if err != nil { - return fmt.Errorf("decode blinded id: %w", err) - } - - var dataLen uint16 - err = tlv.DUint16(r, &dataLen, buf, 2) - if err != nil { - return fmt.Errorf("decode data len: %w", err) - } - - err = tlv.DVarBytes(r, &b.CipherText, buf, uint64(dataLen)) - if err != nil { - return fmt.Errorf("decode data: %w", err) - } - - return nil - } - - return tlv.NewTypeForDecodingErr(val, "*sphinx.BlindedHopInfo", 0, 0) -} diff --git a/lnwire/onion_msg_payload_test.go b/lnwire/onion_msg_payload_test.go index 3a85ec60861..6871f9926bb 100644 --- a/lnwire/onion_msg_payload_test.go +++ b/lnwire/onion_msg_payload_test.go @@ -5,15 +5,14 @@ import ( "fmt" "testing" - sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" "pgregory.net/rapid" ) // makeBlindedPath creates a BlindedPath with the given number of hops for -// testing. Each hop has a random blinded node pub and some cipher text. -func makeBlindedPath(t *testing.T, numHops int) *sphinx.BlindedPath { +// testing. Each hop has a random blinded node ID and some cipher text. +func makeBlindedPath(t *testing.T, numHops int) *BlindedPath { t.Helper() introKey, err := randPubKey() @@ -22,55 +21,48 @@ func makeBlindedPath(t *testing.T, numHops int) *sphinx.BlindedPath { blindingKey, err := randPubKey() require.NoError(t, err) - hops := make([]*sphinx.BlindedHopInfo, numHops) + hops := make([]BlindedHop, numHops) for i := range hops { nodePub, err := randPubKey() require.NoError(t, err) - hops[i] = &sphinx.BlindedHopInfo{ - BlindedNodePub: nodePub, - CipherText: bytes.Repeat([]byte{byte(i + 1)}, 32), - } + hops[i].BlindedNodeID = nodePub + hops[i].EncryptedData = bytes.Repeat([]byte{byte(i + 1)}, 32) } - return &sphinx.BlindedPath{ - IntroductionPoint: introKey, - BlindingPoint: blindingKey, - BlindedHops: hops, + return &BlindedPath{ + IntroductionNode: PubkeyIntro{Pubkey: introKey}, + BlindingPoint: blindingKey, + Hops: hops, } } -// assertBlindedPathEqual compares two BlindedPaths for equality, checking each -// field. -func assertBlindedPathEqual(t *testing.T, expected, - actual *sphinx.BlindedPath) { - +// assertBlindedPathEqual compares two BlindedPaths field-by-field. Direct +// require.Equal would also work, but the per-field assertions surface +// localised mismatches for easier triage. +func assertBlindedPathEqual(t *testing.T, expected, actual *BlindedPath) { t.Helper() - require.True( - t, - expected.IntroductionPoint.IsEqual(actual.IntroductionPoint), - "IntroductionPoint mismatch", + require.Equal( + t, expected.IntroductionNode, actual.IntroductionNode, + "IntroductionNode mismatch", ) - require.True( - t, expected.BlindingPoint.IsEqual(actual.BlindingPoint), + require.Equal( + t, expected.BlindingPoint, actual.BlindingPoint, "BlindingPoint mismatch", ) - require.Len(t, actual.BlindedHops, len(expected.BlindedHops)) - - for i, expectedHop := range expected.BlindedHops { - actualHop := actual.BlindedHops[i] + require.Len(t, actual.Hops, len(expected.Hops)) - require.True( - t, - expectedHop.BlindedNodePub.IsEqual( - actualHop.BlindedNodePub, - ), - "hop %d: BlindedNodePub mismatch", i, + for i := range expected.Hops { + require.Equal( + t, expected.Hops[i].BlindedNodeID, + actual.Hops[i].BlindedNodeID, + "hop %d: BlindedNodeID mismatch", i, ) require.Equal( - t, expectedHop.CipherText, actualHop.CipherText, - "hop %d: CipherText mismatch", i, + t, expected.Hops[i].EncryptedData, + actual.Hops[i].EncryptedData, + "hop %d: EncryptedData mismatch", i, ) } } @@ -112,6 +104,29 @@ func TestOnionMessagePayloadRoundTrip(t *testing.T) { require.Empty(t, decoded.FinalHopTLVs) }) + t.Run("sciddir intro reply path", func(t *testing.T) { + t.Parallel() + + path := makeBlindedPath(t, 2) + path.IntroductionNode = SciddirIntro{ + Direction: 0x01, + SCID: [scidLen]byte{ + 0x00, 0x11, 0x22, 0x33, + 0x44, 0x55, 0x66, 0x77, + }, + } + + original := &OnionMessagePayload{ReplyPath: path} + + decoded := encodeAndDecode(t, original) + + require.NotNil(t, decoded.ReplyPath) + require.IsType( + t, SciddirIntro{}, decoded.ReplyPath.IntroductionNode, + ) + assertBlindedPathEqual(t, original.ReplyPath, decoded.ReplyPath) + }) + t.Run("only encrypted data", func(t *testing.T) { t.Parallel() @@ -351,15 +366,15 @@ func TestOnionMessagePayloadEncodeReplyPathNoHops(t *testing.T) { require.NoError(t, err) payload := &OnionMessagePayload{ - ReplyPath: &sphinx.BlindedPath{ - IntroductionPoint: introKey, - BlindingPoint: blindingKey, - BlindedHops: nil, + ReplyPath: &BlindedPath{ + IntroductionNode: PubkeyIntro{Pubkey: introKey}, + BlindingPoint: blindingKey, + Hops: nil, }, } _, err = payload.Encode() - require.ErrorIs(t, err, ErrNoHops) + require.ErrorIs(t, err, ErrEmptyBlindedPath) } // TestOnionMessagePayloadEmpty tests that an empty payload roundtrips @@ -442,35 +457,9 @@ func TestOnionMessagePayloadRoundTripQuickCheck(t *testing.T) { require.Nil(t, decoded.ReplyPath) } else { require.NotNil(t, decoded.ReplyPath) - require.True( - t, - original.ReplyPath.IntroductionPoint.IsEqual( - decoded.ReplyPath.IntroductionPoint, - ), - ) - require.True( - t, - original.ReplyPath.BlindingPoint.IsEqual( - decoded.ReplyPath.BlindingPoint, - ), - ) - require.Len( - t, decoded.ReplyPath.BlindedHops, - len(original.ReplyPath.BlindedHops), + require.Equal( + t, original.ReplyPath, decoded.ReplyPath, ) - for i, hop := range original.ReplyPath.BlindedHops { - dHop := decoded.ReplyPath.BlindedHops[i] - require.True( - t, - hop.BlindedNodePub.IsEqual( - dHop.BlindedNodePub, - ), - ) - require.Equal( - t, hop.CipherText, - dHop.CipherText, - ) - } } // Verify encrypted data. diff --git a/lnwire/test_utils.go b/lnwire/test_utils.go index 602724a9dd6..ae70aeeb543 100644 --- a/lnwire/test_utils.go +++ b/lnwire/test_utils.go @@ -9,7 +9,6 @@ import ( "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/fn/v2" "github.com/stretchr/testify/require" "pgregory.net/rapid" @@ -59,30 +58,48 @@ func RandPubKey(t *rapid.T) *btcec.PublicKey { return pub } -// RandBlindedPath generates a random blinded path with 1-5 hops. -func RandBlindedPath(t *rapid.T) *sphinx.BlindedPath { - introKey := RandPubKey(t) - blindingKey := RandPubKey(t) +// RandBlindedPath generates a random blinded path with 1-5 hops, alternating +// between the pubkey and sciddir introduction-node variants per draw. +func RandBlindedPath(t *rapid.T) *BlindedPath { + useSciddir := rapid.Bool().Draw(t, "introIsSciddir") + + var intro IntroductionNode + if useSciddir { + var scid [scidLen]byte + copy(scid[:], rapid.SliceOfN( + rapid.Byte(), scidLen, scidLen, + ).Draw(t, "introScid")) + + intro = SciddirIntro{ + Direction: byte( + rapid.IntRange(0, 1).Draw(t, "introDir"), + ), + SCID: scid, + } + } else { + intro = PubkeyIntro{Pubkey: RandPubKey(t)} + } + + blindingPoint := RandPubKey(t) numHops := rapid.IntRange(1, 5).Draw(t, "numBlindedHops") - hops := make([]*sphinx.BlindedHopInfo, numHops) + hops := make([]BlindedHop, numHops) for i := range hops { cipherLen := rapid.IntRange(1, 128).Draw( t, fmt.Sprintf("cipherLen-%d", i), ) - hops[i] = &sphinx.BlindedHopInfo{ - BlindedNodePub: RandPubKey(t), - CipherText: rapid.SliceOfN( - rapid.Byte(), cipherLen, cipherLen, - ).Draw(t, fmt.Sprintf("cipherText-%d", i)), - } + hops[i].BlindedNodeID = RandPubKey(t) + + hops[i].EncryptedData = rapid.SliceOfN( + rapid.Byte(), cipherLen, cipherLen, + ).Draw(t, fmt.Sprintf("cipherText-%d", i)) } - return &sphinx.BlindedPath{ - IntroductionPoint: introKey, - BlindingPoint: blindingKey, - BlindedHops: hops, + return &BlindedPath{ + IntroductionNode: intro, + BlindingPoint: blindingPoint, + Hops: hops, } } diff --git a/onionmessage/onion_endpoint.go b/onionmessage/onion_endpoint.go index f6a32d2ebd9..0c9829e2463 100644 --- a/onionmessage/onion_endpoint.go +++ b/onionmessage/onion_endpoint.go @@ -1,7 +1,7 @@ package onionmessage import ( - sphinx "github.com/lightningnetwork/lightning-onion" + "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" ) @@ -26,7 +26,7 @@ type OnionMessageUpdate struct { CustomRecords record.CustomSet // ReplyPath contains the reply path information for the onion message. - ReplyPath *sphinx.BlindedPath + ReplyPath *lnwire.BlindedPath // EncryptedRecipientData contains the encrypted recipient data for the // onion message, created by the creator of the blinded route. This is diff --git a/routing/route/blindedroute.go b/routing/route/blindedroute.go index 2b8120ad6e2..35ad07031a5 100644 --- a/routing/route/blindedroute.go +++ b/routing/route/blindedroute.go @@ -13,7 +13,7 @@ import ( // payloads used to encoding the routing data for each hop in the route. This // method also accepts final hop payloads. func OnionMessageBlindedPathToSphinxPath(blindedPath *sphinx.BlindedPath, - replyPath *sphinx.BlindedPath, finalHopTLVs []*lnwire.FinalHopTLV) ( + replyPath *lnwire.BlindedPath, finalHopTLVs []*lnwire.FinalHopTLV) ( *sphinx.PaymentPath, error) { var path sphinx.PaymentPath diff --git a/rpcserver.go b/rpcserver.go index 491bd8a1426..6cbe86291b5 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -9570,15 +9570,19 @@ func (r *rpcServer) SubscribeOnionMessages( //nolint:ll if oMsg.ReplyPath != nil { - bp.IntroductionNode = oMsg.ReplyPath.IntroductionPoint.SerializeCompressed() + // TODO(bolt12): resolve sciddir intros via + // sciddirResolver so this field is uniformly a + // 33-byte pubkey? + bp.IntroductionNode = oMsg.ReplyPath.IntroductionNode.Bytes() bp.BlindingPoint = oMsg.ReplyPath.BlindingPoint.SerializeCompressed() - for _, hop := range oMsg.ReplyPath.BlindedHops { - rpcHop := &lnrpc.BlindedHop{ - BlindedNode: hop.BlindedNodePub.SerializeCompressed(), - EncryptedData: hop.CipherText, - } - bp.BlindedHops = append(bp.BlindedHops, rpcHop) + for _, hop := range oMsg.ReplyPath.Hops { + bp.BlindedHops = append( + bp.BlindedHops, &lnrpc.BlindedHop{ + BlindedNode: hop.BlindedNodeID.SerializeCompressed(), + EncryptedData: hop.EncryptedData, + }, + ) } } From c46db79536d64c264cae9f8dbcfc0b0f519feee0 Mon Sep 17 00:00:00 2001 From: bitromortac Date: Wed, 6 May 2026 16:39:59 +0200 Subject: [PATCH 05/11] bolt12: add chains TLV subtype Introduce the ChainsRecord subtype used by the offer_chains and invoice_chains TLV fields. Decoding caps the count at maxOfferChains to bound allocation. --- bolt12/doc.go | 19 +++++ bolt12/subtypes.go | 87 +++++++++++++++++++++++ bolt12/subtypes_test.go | 154 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 260 insertions(+) create mode 100644 bolt12/doc.go create mode 100644 bolt12/subtypes.go create mode 100644 bolt12/subtypes_test.go diff --git a/bolt12/doc.go b/bolt12/doc.go new file mode 100644 index 00000000000..c58a1ced1d3 --- /dev/null +++ b/bolt12/doc.go @@ -0,0 +1,19 @@ +// Package bolt12 implements encoding, decoding, and validation for BOLT 12 +// Offers, Invoice Requests, and Invoices. It provides a pure codec library +// with no LND daemon dependencies. +// +// BOLT 12 messages use TLV streams encoded with a checksumless bech32 variant +// and signed with BIP-340 Schnorr signatures over a Merkle tree of TLV fields. +// +// Human-readable prefixes: +// - lno: Offer +// - lnr: Invoice Request +// - lni: Invoice +// +// # Codec Contract +// +// Encode validates before serialising and refuses to emit bytes that would fail +// the writer requirements, invalid bytes are unrepresentable on the wire. +// Low-level decoders stay permissive so diagnostic and fuzz harnesses can +// inspect malformed input. +package bolt12 diff --git a/bolt12/subtypes.go b/bolt12/subtypes.go new file mode 100644 index 00000000000..dcd5e41de15 --- /dev/null +++ b/bolt12/subtypes.go @@ -0,0 +1,87 @@ +package bolt12 + +import ( + "errors" + "fmt" + "io" + + "github.com/lightningnetwork/lnd/tlv" +) + +// ErrTooManyChains is returned when offer_chains declares more entries than +// maxOfferChains. +var ErrTooManyChains = errors.New("offer_chains exceeds maxOfferChains") + +const ( + // chainHashLen is the length of a chain hash (32 bytes). + chainHashLen = 32 + + // maxOfferChains caps decoded offer_chains entries. This is a sanity + // check to prevent excessive memory allocation and is not a protocol + // limit but a local implementation choice. + maxOfferChains = 32 +) + +// ChainsRecord holds one or more chain hashes for the offer_chains field. +type ChainsRecord struct { + Chains [][chainHashLen]byte +} + +var _ tlv.RecordProducer = (*ChainsRecord)(nil) + +// Record returns a TLV record for ChainsRecord. +func (c *ChainsRecord) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 0, c, + func() uint64 { + return uint64(len(c.Chains)) * chainHashLen + }, + encodeChainsRecord, + decodeChainsRecord, + ) +} + +// encodeChainsRecord writes the chain hashes in sequence, without a count +// prefix. +func encodeChainsRecord(w io.Writer, val any, _ *[8]byte) error { + c, ok := val.(*ChainsRecord) + if !ok { + return fmt.Errorf("expected *ChainsRecord, got %T", val) + } + + for _, chain := range c.Chains { + if _, err := w.Write(chain[:]); err != nil { + return err + } + } + + return nil +} + +// decodeChainsRecord caps the count at maxOfferChains to bound allocation. +func decodeChainsRecord(r io.Reader, val any, _ *[8]byte, l uint64) error { + c, ok := val.(*ChainsRecord) + if !ok { + return fmt.Errorf("expected *ChainsRecord, got %T", val) + } + + if l%chainHashLen != 0 { + return fmt.Errorf("chains length %d not a multiple of %d", l, + chainHashLen) + } + + numChains := l / chainHashLen + if numChains > maxOfferChains { + return fmt.Errorf("%w: %d > %d", ErrTooManyChains, numChains, + maxOfferChains) + } + + c.Chains = make([][chainHashLen]byte, numChains) + for i := range c.Chains { + if _, err := io.ReadFull(r, c.Chains[i][:]); err != nil { + return err + } + } + + return nil +} diff --git a/bolt12/subtypes_test.go b/bolt12/subtypes_test.go new file mode 100644 index 00000000000..ceecd14a084 --- /dev/null +++ b/bolt12/subtypes_test.go @@ -0,0 +1,154 @@ +package bolt12 + +import ( + "bytes" + "encoding/hex" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestDecodeChainsRecord pins the chain-array decoder's structural rejections. +func TestDecodeChainsRecord(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + data []byte + wantErr error + wantMsg string + }{ + { + name: "length not multiple of 32", + data: append( + bytes.Repeat( + []byte{0xaa}, chainHashLen, + ), + 187, + ), + wantMsg: "not a multiple of", + }, + { + name: "exceeds cap", + data: bytes.Repeat( + []byte{0x00}, (maxOfferChains+1)*chainHashLen, + ), + wantErr: ErrTooManyChains, + }, + } + + for _, tc := range tests { + t.Run( + tc.name, + func(t *testing.T) { + t.Parallel() + + var c ChainsRecord + err := decodeChainsRecord( + bytes.NewReader(tc.data), &c, + new([8]byte), + uint64( + len(tc.data), + ), + ) + require.Error(t, err) + + if tc.wantErr != nil { + require.ErrorIs(t, err, tc.wantErr) + } + + if tc.wantMsg != "" { + require.Contains( + t, err.Error(), tc.wantMsg, + ) + } + }, + ) + } +} + +// TestChainsRecordRoundTrip pins decode→re-encode against the BOLT 12 offer +// test vectors. +func TestChainsRecordRoundTrip(t *testing.T) { + t.Parallel() + + // bitcoinHash is the bitcoin mainnet genesis hash hex-decoded into a + // fixed array. Defined locally so the test does not depend on constants + // introduced by later commits. + bitcoinHashHex := "6fe28c0ab6f1b372c1a6a246ae63f74f931e8365" + + "e15a089c68d6190000000000" + + var bitcoinHash [chainHashLen]byte + bitcoinHashBytes, err := hex.DecodeString(bitcoinHashHex) + require.NoError(t, err) + copy(bitcoinHash[:], bitcoinHashBytes) + + tests := []struct { + name string + // hex is the on-wire bytes of the offer_chains TLV value + // (concatenated 32-byte chain hashes), copied from + // bolt12/offers-test.json. + hex string + wantLen int + wantHash [chainHashLen]byte + }{ + { + name: "single testnet chain", + hex: "43497fd7f826957108f4a30fd9cec3ae" + + "ba79972084e90ead01ea330900000000", + wantLen: 1, + }, + { + name: "single bitcoin chain", + hex: bitcoinHashHex, + wantLen: 1, + wantHash: bitcoinHash, + }, + { + name: "two chains liquidv1 then bitcoin", + hex: "1466275836220db2944ca059a3a10ef6fd2ea684b" + + "0688d2c379296888a206003" + bitcoinHashHex, + wantLen: 2, + // Second chain in the list is bitcoin mainnet. + wantHash: bitcoinHash, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data, err := hex.DecodeString(tc.hex) + require.NoError(t, err) + + var c ChainsRecord + err = decodeChainsRecord( + bytes.NewReader(data), &c, new([8]byte), + uint64( + len(data), + ), + ) + require.NoError(t, err) + require.Len(t, c.Chains, tc.wantLen) + + // Cross-check the canonical bitcoin chain hash where + // the row knows which slot it lives in. + var zero [chainHashLen]byte + if tc.wantHash != zero { + idx := tc.wantLen - 1 + require.Equal( + t, tc.wantHash, c.Chains[idx], + "bitcoin hash mismatch in slot %d", + idx, + ) + } + + var buf bytes.Buffer + require.NoError( + t, encodeChainsRecord(&buf, &c, new([8]byte)), + ) + + require.Equal(t, data, buf.Bytes()) + }) + } +} From 49c26ac086ea78d5e9ae70f7d2416f74e7453f6e Mon Sep 17 00:00:00 2001 From: bitromortac Date: Fri, 8 May 2026 10:05:35 +0200 Subject: [PATCH 06/11] bolt12: add Offer message struct with TLV codec The Offer struct models a long-lived, reusable BOLT 12 payment template. It defines TLV fields as optional records and exposes Encode/DecodeOffer for round-trip serialization. The struct implements lnwire.PureTLVMessage; AllRecords filters the decoded TypeMap through bolt12InUnsignedRange to derive any signed-range extras the encoder must re-emit, keeping offer_id and the Merkle root stable across encoders that understand a wider set of even/odd extensions. --- bolt12/decode.go | 29 ++++++++ bolt12/helpers_test.go | 16 ++++ bolt12/offer.go | 163 +++++++++++++++++++++++++++++++++++++++++ bolt12/offer_test.go | 49 +++++++++++++ bolt12/pure_tlv.go | 38 ++++++++++ bolt12/tlv_types.go | 22 ++++++ 6 files changed, 317 insertions(+) create mode 100644 bolt12/decode.go create mode 100644 bolt12/helpers_test.go create mode 100644 bolt12/offer.go create mode 100644 bolt12/offer_test.go create mode 100644 bolt12/pure_tlv.go create mode 100644 bolt12/tlv_types.go diff --git a/bolt12/decode.go b/bolt12/decode.go new file mode 100644 index 00000000000..4307455b808 --- /dev/null +++ b/bolt12/decode.go @@ -0,0 +1,29 @@ +package bolt12 + +import ( + "bytes" + "fmt" + + "github.com/lightningnetwork/lnd/tlv" +) + +// decodeStream runs a single typed-stream pass over data and returns the +// canonical TypeMap. Records may be passed in any order; NewStream requires +// them sorted, so SortRecords runs first. +func decodeStream(data []byte, records ...tlv.Record) (tlv.TypeMap, error) { + tlv.SortRecords(records) + + stream, err := tlv.NewStream(records...) + if err != nil { + return nil, fmt.Errorf("create stream: %w", err) + } + + typeMap, err := stream.DecodeWithParsedTypesP2P( + bytes.NewReader(data), + ) + if err != nil { + return nil, fmt.Errorf("decode stream: %w", err) + } + + return typeMap, nil +} diff --git a/bolt12/helpers_test.go b/bolt12/helpers_test.go new file mode 100644 index 00000000000..f301ad21b20 --- /dev/null +++ b/bolt12/helpers_test.go @@ -0,0 +1,16 @@ +package bolt12 + +import ( + "bytes" + + "github.com/btcsuite/btcd/btcec/v2" +) + +// bobKey returns the deterministic spec test key for Bob, whose 32-byte scalar +// is 0x42 repeated. Used across signature and round-trip tests so the same key +// is not reconstructed in every callsite. +func bobKey() (*btcec.PrivateKey, *btcec.PublicKey) { + priv, pub := btcec.PrivKeyFromBytes(bytes.Repeat([]byte{0x42}, 32)) + + return priv, pub +} diff --git a/bolt12/offer.go b/bolt12/offer.go new file mode 100644 index 00000000000..c203fbbadb9 --- /dev/null +++ b/bolt12/offer.go @@ -0,0 +1,163 @@ +package bolt12 + +import ( + "bytes" + "fmt" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +// Offer represents a BOLT 12 offer message. An offer is a long-lived, reusable +// payment template that can generate multiple invoices. +type Offer struct { + // OfferChains specifies which chains this offer is valid for. If + // absent, bitcoin is implied. + OfferChains tlv.OptionalRecordT[tlv.TlvType2, ChainsRecord] + + // OfferMetadata is opaque data set by the offer creator for its own + // use. + OfferMetadata tlv.OptionalRecordT[tlv.TlvType4, tlv.Blob] + + // OfferCurrency is the ISO 4217 currency code for the offer amount, if + // the amount is not in the chain's native unit. + OfferCurrency tlv.OptionalRecordT[tlv.TlvType6, tlv.Blob] + + // OfferAmount is the amount expected per item, encoded as a tu64. The + // unit depends on OfferCurrency (msat if absent). + OfferAmount tlv.OptionalRecordT[tlv.TlvType8, TUint64] + + // OfferDescription is a UTF-8 description of the purpose of the + // payment. + OfferDescription tlv.OptionalRecordT[tlv.TlvType10, tlv.Blob] + + // OfferFeatures is the feature bit vector for this offer. + OfferFeatures tlv.OptionalRecordT[tlv.TlvType12, + lnwire.RawFeatureVector] + + // OfferAbsoluteExpiry is the time (seconds since epoch) after which the + // offer should not be used, encoded as a tu64. + OfferAbsoluteExpiry tlv.OptionalRecordT[tlv.TlvType14, TUint64] + + // OfferPaths contains one or more blinded paths to the offer issuer. + OfferPaths tlv.OptionalRecordT[tlv.TlvType16, lnwire.BlindedPaths] + + // OfferIssuer is a UTF-8 string identifying the issuer. + OfferIssuer tlv.OptionalRecordT[tlv.TlvType18, tlv.Blob] + + // OfferQuantityMax is the maximum number of items that can be requested + // in a single invoice, encoded as a tu64. A value of 0 means unlimited. + OfferQuantityMax tlv.OptionalRecordT[tlv.TlvType20, TUint64] + + // OfferIssuerID is the public key of the offer issuer. The codec + // parses the 33-byte SEC1 compressed point on decode, so a struct + // holding a key has already passed both the length and on-curve + // checks. + OfferIssuerID tlv.OptionalRecordT[tlv.TlvType22, *btcec.PublicKey] + + // decodedTLVs is the canonical TypeMap produced by decoding this offer. + // Handled types map to nil; unhandled types map to their value bytes. + // Encoding and validation both derive their view from this single field + // so they cannot drift apart, and so signed-range extras the decoder + // did not understand are re-emitted on encode and preserve offer_id. + decodedTLVs tlv.TypeMap +} + +var _ lnwire.PureTLVMessage = (*Offer)(nil) + +// AllRecords returns the canonical sorted record list for this offer, merging +// the typed records with any extra signed-range fields that the decoder +// preserved. +func (o *Offer) AllRecords() []tlv.Record { + return allRecordsFromTypeMap( + o.allRecordProducers(), o.decodedTLVs, + ) +} + +// allRecordProducers returns record producers for every set optional field, in +// declaration order. +func (o *Offer) allRecordProducers() []tlv.RecordProducer { + var p []tlv.RecordProducer + + lnwire.AddOpt(&p, o.OfferChains) + lnwire.AddOpt(&p, o.OfferMetadata) + lnwire.AddOpt(&p, o.OfferCurrency) + lnwire.AddOpt(&p, o.OfferAmount) + lnwire.AddOpt(&p, o.OfferDescription) + lnwire.AddOpt(&p, o.OfferFeatures) + lnwire.AddOpt(&p, o.OfferAbsoluteExpiry) + lnwire.AddOpt(&p, o.OfferPaths) + lnwire.AddOpt(&p, o.OfferIssuer) + lnwire.AddOpt(&p, o.OfferQuantityMax) + lnwire.AddOpt(&p, o.OfferIssuerID) + + return p +} + +// Encode serialises the offer into a canonical TLV byte stream. +func (o *Offer) Encode() ([]byte, error) { + var buf bytes.Buffer + if err := lnwire.EncodePureTLVMessage(o, &buf); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// decodeOffer parses a TLV byte stream into an Offer. Decoding is permissive — +// the spec writer requirements are not enforced here, so callers that need a +// valid offer must run ValidateOfferRead. Unknown TLVs are preserved on the +// returned offer so a later Encode can re-emit signed-range extras and keep +// offer_id stable. +func decodeOffer(data []byte) (*Offer, error) { + var o Offer + + // Prepare zero-valued records for all optional fields so the TLV + // decoder can populate them. + chains := tlv.ZeroRecordT[tlv.TlvType2, ChainsRecord]() + metadata := tlv.ZeroRecordT[tlv.TlvType4, tlv.Blob]() + currency := tlv.ZeroRecordT[tlv.TlvType6, tlv.Blob]() + amount := tlv.ZeroRecordT[tlv.TlvType8, TUint64]() + desc := tlv.ZeroRecordT[tlv.TlvType10, tlv.Blob]() + features := tlv.ZeroRecordT[tlv.TlvType12, lnwire.RawFeatureVector]() + expiry := tlv.ZeroRecordT[tlv.TlvType14, TUint64]() + paths := tlv.ZeroRecordT[tlv.TlvType16, lnwire.BlindedPaths]() + issuer := tlv.ZeroRecordT[tlv.TlvType18, tlv.Blob]() + qtyMax := tlv.ZeroRecordT[tlv.TlvType20, TUint64]() + issuerID := tlv.ZeroRecordT[tlv.TlvType22, *btcec.PublicKey]() + + tm, err := decodeStream( + data, + chains.Record(), + metadata.Record(), + currency.Record(), + amount.Record(), + desc.Record(), + features.Record(), + expiry.Record(), + paths.Record(), + issuer.Record(), + qtyMax.Record(), + issuerID.Record(), + ) + if err != nil { + return nil, fmt.Errorf("decode offer: %w", err) + } + + lnwire.SetOptFromMap(tm, &o.OfferChains, chains) + lnwire.SetOptFromMap(tm, &o.OfferMetadata, metadata) + lnwire.SetOptFromMap(tm, &o.OfferCurrency, currency) + lnwire.SetOptFromMap(tm, &o.OfferAmount, amount) + lnwire.SetOptFromMap(tm, &o.OfferDescription, desc) + lnwire.SetOptFromMap(tm, &o.OfferFeatures, features) + lnwire.SetOptFromMap(tm, &o.OfferAbsoluteExpiry, expiry) + lnwire.SetOptFromMap(tm, &o.OfferPaths, paths) + lnwire.SetOptFromMap(tm, &o.OfferIssuer, issuer) + lnwire.SetOptFromMap(tm, &o.OfferQuantityMax, qtyMax) + lnwire.SetOptFromMap(tm, &o.OfferIssuerID, issuerID) + + o.decodedTLVs = tm + + return &o, nil +} diff --git a/bolt12/offer_test.go b/bolt12/offer_test.go new file mode 100644 index 00000000000..2a9ae325bce --- /dev/null +++ b/bolt12/offer_test.go @@ -0,0 +1,49 @@ +package bolt12 + +import ( + "testing" + + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// TestOfferRoundTrip pins encode→decode→re-encode for an Offer with a +// representative subset of optional fields. A byte-identical re-encode is the +// invariant that keeps offer_id stable across the codec boundary. +func TestOfferRoundTrip(t *testing.T) { + t.Parallel() + + desc := tlv.Blob("coffee") + issuer := tlv.Blob("alice") + _, bobPub := bobKey() + + o := &Offer{ + OfferAmount: tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType8](TUint64(1500)), + ), + OfferDescription: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType10](desc), + ), + OfferIssuer: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType18](issuer), + ), + OfferIssuerID: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType22](bobPub), + ), + } + + encoded, err := o.Encode() + require.NoError(t, err) + require.NotEmpty(t, encoded) + + decoded, err := decodeOffer(encoded) + require.NoError(t, err) + + require.Equal(t, TUint64(1500), decoded.OfferAmount.UnwrapOrFailV(t)) + require.Equal(t, desc, decoded.OfferDescription.UnwrapOrFailV(t)) + require.Equal(t, issuer, decoded.OfferIssuer.UnwrapOrFailV(t)) + + reencoded, err := decoded.Encode() + require.NoError(t, err) + require.Equal(t, encoded, reencoded) +} diff --git a/bolt12/pure_tlv.go b/bolt12/pure_tlv.go new file mode 100644 index 00000000000..93a00ebc905 --- /dev/null +++ b/bolt12/pure_tlv.go @@ -0,0 +1,38 @@ +package bolt12 + +import ( + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +// bolt12InUnsignedRange reports whether a TLV type is excluded from the BOLT 12 +// Merkle tree. The spec reserves types 240-1000 for signature TLVs (the BIP-340 +// Schnorr signatures over the tree itself); every other allowed type sits in +// the signed range. +func bolt12InUnsignedRange(t tlv.Type) bool { + return t >= 240 && t <= 1000 +} + +// allRecordsFromTypeMap merges the typed-record producers with the signed-range +// subset of the supplied TypeMap (preserved unknown TLVs) and returns the +// canonical sorted record list. The signed-range subset is derived on demand +// from the same TypeMap that drives the validators, so the two views cannot +// drift apart. +func allRecordsFromTypeMap(producers []tlv.RecordProducer, + tm tlv.TypeMap) []tlv.Record { + + if len(tm) > 0 { + extra := lnwire.ExtraSignedFieldsFromTypeMapFn( + tm, bolt12InUnsignedRange, + ) + if len(extra) > 0 { + producers = append( + producers, lnwire.RecordsAsProducers( + tlv.MapToRecords(extra), + )..., + ) + } + } + + return lnwire.ProduceRecordsSorted(producers...) +} diff --git a/bolt12/tlv_types.go b/bolt12/tlv_types.go new file mode 100644 index 00000000000..b956477c112 --- /dev/null +++ b/bolt12/tlv_types.go @@ -0,0 +1,22 @@ +package bolt12 + +import ( + "github.com/lightningnetwork/lnd/tlv" +) + +// TUint64 is a uint64 that serializes using truncated encoding (tu64) +// as required by BOLT 12. Leading zero bytes are omitted. +type TUint64 uint64 + +// Record returns a TLV record using truncated uint64 encoding. +// +// NOTE: This implements the tlv.RecordProducer interface. +func (t *TUint64) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 0, (*uint64)(t), + func() uint64 { + return tlv.SizeTUint64(uint64(*t)) + }, + tlv.ETUint64, tlv.DTUint64, + ) +} From 4d7cf566d0ef92663f9caeac97c3a9d756c85447 Mon Sep 17 00:00:00 2001 From: bitromortac Date: Wed, 6 May 2026 16:52:12 +0200 Subject: [PATCH 07/11] bolt12: validate Offer per BOLT 12 reader/writer requirements ValidateOfferRead and ValidateOfferWrite enforce the codec-side portion of the BOLT 12 offer reader and writer requirements. Reader rules cover TLV range, even-feature-bit rejection, chain mismatch, dependency rules between offer_amount/description/currency, missing issuer identity, zero-hop blinded paths, and offer expiry. Writer rules mirror the same dependency and identity guards plus a defense-in-depth empty- offer_chains rejection. offer_currency is validated against the ISO 4217 registry via golang.org/x/text/currency (now a direct dependency); offer_issuer_id is verified to be an on-curve SEC1 compressed point on both read and write paths. Encode invokes Validate so invalid bytes never reach the wire. --- bolt12/helpers_test.go | 8 + bolt12/offer.go | 4 + bolt12/pure_tlv.go | 14 ++ bolt12/validate.go | 397 +++++++++++++++++++++++++++++++++++++++ bolt12/validate_test.go | 402 ++++++++++++++++++++++++++++++++++++++++ go.mod | 2 +- 6 files changed, 826 insertions(+), 1 deletion(-) create mode 100644 bolt12/validate.go create mode 100644 bolt12/validate_test.go diff --git a/bolt12/helpers_test.go b/bolt12/helpers_test.go index f301ad21b20..78bbdfdffd2 100644 --- a/bolt12/helpers_test.go +++ b/bolt12/helpers_test.go @@ -14,3 +14,11 @@ func bobKey() (*btcec.PrivateKey, *btcec.PublicKey) { return priv, pub } + +// aliceKey returns the deterministic spec test key for Alice, whose 32-byte +// scalar is 0x41 repeated. +func aliceKey() (*btcec.PrivateKey, *btcec.PublicKey) { + priv, pub := btcec.PrivKeyFromBytes(bytes.Repeat([]byte{0x41}, 32)) + + return priv, pub +} diff --git a/bolt12/offer.go b/bolt12/offer.go index c203fbbadb9..dfe8cd60f4e 100644 --- a/bolt12/offer.go +++ b/bolt12/offer.go @@ -97,6 +97,10 @@ func (o *Offer) allRecordProducers() []tlv.RecordProducer { // Encode serialises the offer into a canonical TLV byte stream. func (o *Offer) Encode() ([]byte, error) { + if err := ValidateOfferWrite(o); err != nil { + return nil, fmt.Errorf("validate offer: %w", err) + } + var buf bytes.Buffer if err := lnwire.EncodePureTLVMessage(o, &buf); err != nil { return nil, err diff --git a/bolt12/pure_tlv.go b/bolt12/pure_tlv.go index 93a00ebc905..d20022b9281 100644 --- a/bolt12/pure_tlv.go +++ b/bolt12/pure_tlv.go @@ -1,6 +1,8 @@ package bolt12 import ( + "slices" + "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" ) @@ -36,3 +38,15 @@ func allRecordsFromTypeMap(producers []tlv.RecordProducer, return lnwire.ProduceRecordsSorted(producers...) } + +// sortedTypes returns the keys of tm in ascending order. Validators iterate the +// result for deterministic out-of-range and unknown-even error messages. +func sortedTypes(tm tlv.TypeMap) []tlv.Type { + out := make([]tlv.Type, 0, len(tm)) + for t := range tm { + out = append(out, t) + } + slices.Sort(out) + + return out +} diff --git a/bolt12/validate.go b/bolt12/validate.go new file mode 100644 index 00000000000..2ed93196e22 --- /dev/null +++ b/bolt12/validate.go @@ -0,0 +1,397 @@ +package bolt12 + +import ( + "errors" + "fmt" + "slices" + "time" + "unicode/utf8" + + "github.com/btcsuite/btcd/chaincfg" + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" + "golang.org/x/text/currency" +) + +var ( + // ErrOutOfRangeType is returned when a TLV type falls outside the + // allowed offer ranges (1-79 and 1000000000-1999999999). + ErrOutOfRangeType = errors.New("TLV type outside allowed range") + + // ErrUnknownEvenType is returned when an unknown even TLV type is + // present in an allowed range. Per BOLT 1, even types are + // must-understand: if the reader does not recognise the type, it MUST + // reject the message rather than silently ignoring the field. + ErrUnknownEvenType = errors.New("unknown even TLV type") + + // ErrUnknownEvenFeature is returned when an unknown even feature + // bit is set. + ErrUnknownEvenFeature = errors.New("unknown even feature bit set") + + // ErrMissingDescription is returned when offer_amount is set but + // offer_description is absent. + ErrMissingDescription = errors.New( + "offer_amount set without offer_description", + ) + + // ErrCurrencyWithoutAmount is returned when offer_currency is set + // but offer_amount is absent. + ErrCurrencyWithoutAmount = errors.New( + "offer_currency set without offer_amount", + ) + + // ErrZeroAmount is returned when offer_amount is set to zero. The spec + // requires a present offer_amount to be strictly greater than zero so + // that a zero-value cannot masquerade as "no minimum required". + ErrZeroAmount = errors.New("offer_amount must be greater than zero") + + // ErrEmptyBlindedPaths is returned when a blinded paths field is + // present on a BOLT 12 message but its list of paths is empty. The + // spec writer requirements treat "present" as implying at least one + // usable path. + ErrEmptyBlindedPaths = errors.New("blinded paths field present but " + + "empty") + + // ErrNoIssuerIdentity is returned when neither offer_issuer_id + // nor offer_paths is set. + ErrNoIssuerIdentity = errors.New( + "neither offer_issuer_id nor offer_paths set", + ) + + // ErrOfferExpired is returned when the current time is after + // offer_absolute_expiry. + ErrOfferExpired = errors.New("offer has expired") + + // ErrEmptyChains is returned when offer_chains is present but + // contains no entries. + ErrEmptyChains = errors.New( + "offer_chains present but empty", + ) + + // ErrUnsupportedChain is returned when offer_chains does not + // contain our active chain. + ErrUnsupportedChain = errors.New( + "offer does not support our chain", + ) + + // ErrInvalidUTF8 is returned when a UTF-8 field contains invalid + // sequences. + ErrInvalidUTF8 = errors.New("invalid UTF-8") + + // ErrInvalidCurrency is returned when offer_currency is not a valid ISO + // 4217 code. + ErrInvalidCurrency = errors.New("invalid offer_currency") +) + +// offerAllowedRange returns true if the TLV type falls within the allowed +// ranges for offer messages: 1-79 and 1000000000-1999999999. +func offerAllowedRange(typ tlv.Type) bool { + return (typ >= 1 && typ <= 79) || + (typ >= 1000000000 && typ <= 1999999999) +} + +// isKnownOfferTLVType returns true for TLV types that are defined in the offer +// spec (even types 2-22). +func isKnownOfferTLVType(typ tlv.Type) bool { + switch typ { + case 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22: + return true + default: + return false + } +} + +// ValidateOfferRead validates an offer per the BOLT 12 offer reader +// requirements. The now parameter is used for expiry checks and can be +// overridden in tests. activeChain is required: per spec, absent offer_chains +// defaults to Bitcoin mainnet, and the reader must reject offers that do not +// list a chain it operates on. Pass the genesis hash of the chain the receiver +// is willing to settle on. +func ValidateOfferRead(o *Offer, now time.Time, activeChain [32]byte) error { + // Check TLV types are in allowed range and that unknown even types are + // rejected (even = must-understand). + for _, t := range sortedTypes(o.decodedTLVs) { + if !offerAllowedRange(t) { + return fmt.Errorf("%w: type %d", ErrOutOfRangeType, t) + } + + if !isKnownOfferTLVType(t) && t%2 == 0 { + return fmt.Errorf("%w: type %d", ErrUnknownEvenType, t) + } + } + + // Check for unknown even feature bits. + if err := checkFeatures(o.OfferFeatures); err != nil { + return err + } + + // offer_chains present but empty. + var chainsEmpty bool + o.OfferChains.WhenSome( + func(r tlv.RecordT[tlv.TlvType2, ChainsRecord]) { + if len(r.Val.Chains) == 0 { + chainsEmpty = true + } + }, + ) + if chainsEmpty { + return ErrEmptyChains + } + + // Validate the offer's chain against the active chain. An absent + // offer_chains TLV means "Bitcoin mainnet" per spec, normalised by + // getOfferChains. + offerChains := getOfferChains(o) + found := slices.Contains(offerChains, activeChain) + if !found { + return ErrUnsupportedChain + } + + // offer_amount set requires offer_description. + hasAmount := o.OfferAmount.IsSome() + if hasAmount && !o.OfferDescription.IsSome() { + return ErrMissingDescription + } + + // offer_amount, if set, must be strictly greater than zero. + if err := checkAmountPositive(o.OfferAmount); err != nil { + return err + } + + // offer_currency requires offer_amount. + if o.OfferCurrency.IsSome() && !hasAmount { + return ErrCurrencyWithoutAmount + } + + // Must have either offer_issuer_id or offer_paths. + if !o.OfferIssuerID.IsSome() && !o.OfferPaths.IsSome() { + return ErrNoIssuerIdentity + } + + // Check blinded paths have at least one hop. + if err := checkBlindedPaths(o.OfferPaths); err != nil { + return err + } + + // Expiry check. A present-but-zero offer_absolute_expiry historically + // meant "no expiry" but that conflicts with the spec (zero is a valid + // past timestamp); treat it as already expired rather than ambiguous so + // a misuse fails closed. + var ( + expiry uint64 + hasExpiry bool + ) + o.OfferAbsoluteExpiry.WhenSome( + func(r tlv.RecordT[tlv.TlvType14, TUint64]) { + expiry = uint64(r.Val) + hasExpiry = true + }, + ) + if hasExpiry && uint64(now.Unix()) > expiry { + return ErrOfferExpired + } + + // Validate UTF-8 fields. + if err := checkUTF8(o.OfferCurrency, "offer_currency"); err != nil { + return err + } + + if err := checkUTF8( + o.OfferDescription, "offer_description", + ); err != nil { + return err + } + + if err := checkUTF8(o.OfferIssuer, "offer_issuer"); err != nil { + return err + } + + if err := checkISO4217(o.OfferCurrency); err != nil { + return err + } + + return nil +} + +// bitcoinMainnetGenesisHash is the genesis hash for Bitcoin mainnet, used as +// the default when offer_chains is absent per the spec. +var bitcoinMainnetGenesisHash = [32]byte(*chaincfg.MainNetParams.GenesisHash) + +// getOfferChains returns the chains an offer is valid for. If offer_chains is +// absent, the spec defaults to Bitcoin mainnet. +func getOfferChains(o *Offer) [][32]byte { + chains := fn.MapOptionZ( + o.OfferChains.ValOpt(), + func(r ChainsRecord) [][32]byte { return r.Chains }, + ) + + if len(chains) == 0 { + chains = [][32]byte{bitcoinMainnetGenesisHash} + } + + return chains +} + +// ValidateOfferWrite validates an offer per the BOLT 12 offer writer +// requirements. +func ValidateOfferWrite(o *Offer) error { + // Writer MUST NOT set TLV fields outside allowed ranges. This check + // catches a decoded-then-mutated offer: a freshly-built struct has no + // decodedTLVs (Decode is the only writer of that field). The typed + // field set already excludes out-of-range types by construction, so a + // freshly-built offer cannot violate the range rule in the first place. + for _, t := range sortedTypes(o.decodedTLVs) { + if !offerAllowedRange(t) { + return fmt.Errorf("%w: type %d", + ErrOutOfRangeType, t) + } + } + + // offer_amount requires offer_description. + if o.OfferAmount.IsSome() && !o.OfferDescription.IsSome() { + return ErrMissingDescription + } + + // offer_amount, if set, must be strictly greater than zero. + if err := checkAmountPositive(o.OfferAmount); err != nil { + return err + } + + // offer_currency requires offer_amount. + if o.OfferCurrency.IsSome() && !o.OfferAmount.IsSome() { + return ErrCurrencyWithoutAmount + } + + // Without offer_paths, MUST set offer_issuer_id. + if !o.OfferPaths.IsSome() && !o.OfferIssuerID.IsSome() { + return ErrNoIssuerIdentity + } + + // Defense in depth: writer-side mirrors of reader rejections for + // present-but-empty offer_chains and offer_paths. + var chainsEmpty bool + o.OfferChains.WhenSome( + func(r tlv.RecordT[tlv.TlvType2, ChainsRecord]) { + if len(r.Val.Chains) == 0 { + chainsEmpty = true + } + }, + ) + if chainsEmpty { + return ErrEmptyChains + } + + if err := checkBlindedPaths(o.OfferPaths); err != nil { + return err + } + + // Defense in depth: writer-side mirrors of the reader UTF-8 checks + // for offer_currency, offer_description, and offer_issuer. + if err := checkUTF8(o.OfferCurrency, "offer_currency"); err != nil { + return err + } + + if err := checkUTF8( + o.OfferDescription, "offer_description", + ); err != nil { + return err + } + + if err := checkUTF8(o.OfferIssuer, "offer_issuer"); err != nil { + return err + } + + if err := checkISO4217(o.OfferCurrency); err != nil { + return err + } + + return nil +} + +// checkISO4217 verifies that offer_currency, if set, parses as an ISO 4217 +// code. The upstream parser is case-insensitive and rejects both malformed and +// unrecognised codes. +func checkISO4217[T tlv.TlvType](opt tlv.OptionalRecordT[T, tlv.Blob]) error { + return fn.MapOptionZ(opt.ValOpt(), func(data tlv.Blob) error { + if _, err := currency.ParseISO(string(data)); err != nil { + return fmt.Errorf("%w: %w", ErrInvalidCurrency, err) + } + + return nil + }) +} + +// checkFeatures rejects any unknown even (must-understand) feature bit. +func checkFeatures[T tlv.TlvType]( + opt tlv.OptionalRecordT[T, lnwire.RawFeatureVector]) error { + + return fn.MapOptionZ( + opt.ValOpt(), + func(fv lnwire.RawFeatureVector) error { + // nil catalogue: BOLT 12 defines no feature bits yet, + // so every set even bit is "unknown". Swap in a + // Bolt12Features map once the spec assigns bits. + wrapped := lnwire.NewFeatureVector(&fv, nil) + unknown := wrapped.UnknownRequiredFeatures() + if len(unknown) == 0 { + return nil + } + + // Sort for deterministic errors. + slices.Sort(unknown) + return fmt.Errorf("%w: bit %d", + ErrUnknownEvenFeature, unknown[0]) + }, + ) +} + +// checkBlindedPaths walks each path in a blinded paths field and rejects empty +// Paths slices and paths with zero hops. +func checkBlindedPaths[T tlv.TlvType]( + opt tlv.OptionalRecordT[T, lnwire.BlindedPaths]) error { + + return fn.MapOptionZ( + opt.ValOpt(), + func(paths lnwire.BlindedPaths) error { + if len(paths.Paths) == 0 { + return ErrEmptyBlindedPaths + } + + for i, p := range paths.Paths { + if len(p.Hops) == 0 { + return fmt.Errorf("%w: path %d", + lnwire.ErrEmptyBlindedPath, i) + } + } + + return nil + }, + ) +} + +// checkAmountPositive rejects an offer_amount that is present but zero. +func checkAmountPositive[T tlv.TlvType]( + opt tlv.OptionalRecordT[T, TUint64]) error { + + return fn.MapOptionZ(opt.ValOpt(), func(v TUint64) error { + if v == 0 { + return ErrZeroAmount + } + + return nil + }) +} + +// checkUTF8 validates that a blob field contains valid UTF-8. +func checkUTF8[T tlv.TlvType](opt tlv.OptionalRecordT[T, tlv.Blob], + name string) error { + + return fn.MapOptionZ(opt.ValOpt(), func(data tlv.Blob) error { + if !utf8.Valid(data) { + return fmt.Errorf("%w: %s", ErrInvalidUTF8, name) + } + + return nil + }) +} diff --git a/bolt12/validate_test.go b/bolt12/validate_test.go new file mode 100644 index 00000000000..b1f6056839a --- /dev/null +++ b/bolt12/validate_test.go @@ -0,0 +1,402 @@ +package bolt12 + +import ( + "testing" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// validBobOffer is the spec-minimal happy-path offer that each table row +// mutates to isolate the rule under test. +func validBobOffer(t *testing.T) *Offer { + t.Helper() + + _, pub := bobKey() + + return &Offer{ + OfferIssuerID: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType22](pub), + ), + } +} + +// TestValidateOfferWrite pins the BOLT 12 writer-side MUSTs that the codec can +// enforce. +func TestValidateOfferWrite(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mutate func(*Offer) + wantErr error + }{ + { + name: "happy path with issuer_id only", + mutate: func(*Offer) {}, + wantErr: nil, + }, + { + name: "amount without description", + mutate: func(o *Offer) { + o.OfferAmount = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType8]( + TUint64(1000), + ), + ) + }, + wantErr: ErrMissingDescription, + }, + { + name: "currency without amount", + mutate: func(o *Offer) { + o.OfferCurrency = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType6, + tlv.Blob](tlv.Blob("USD")), + ) + }, + wantErr: ErrCurrencyWithoutAmount, + }, + { + name: "zero amount with description", + mutate: func(o *Offer) { + o.OfferAmount = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType8]( + TUint64(0), + ), + ) + o.OfferDescription = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType10, + tlv.Blob](tlv.Blob("a tip")), + ) + }, + wantErr: ErrZeroAmount, + }, + { + name: "no issuer or paths", + mutate: func(o *Offer) { + o.OfferIssuerID = tlv.OptionalRecordT[ + tlv.TlvType22, *btcec.PublicKey]{} + }, + wantErr: ErrNoIssuerIdentity, + }, + { + name: "empty offer_chains", + mutate: func(o *Offer) { + o.OfferChains = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType2]( + ChainsRecord{Chains: nil}, + ), + ) + }, + wantErr: ErrEmptyChains, + }, + { + name: "currency wrong length", + mutate: func(o *Offer) { + addAmountAndDescription(o) + o.OfferCurrency = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType6, + tlv.Blob](tlv.Blob("US")), + ) + }, + wantErr: ErrInvalidCurrency, + }, + { + name: "currency unknown ISO 4217 code", + mutate: func(o *Offer) { + addAmountAndDescription(o) + o.OfferCurrency = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType6, + tlv.Blob](tlv.Blob("ZZZ")), + ) + }, + wantErr: ErrInvalidCurrency, + }, + { + // Pins the docstring claim that ValidateOfferWrite's + // offerAllowedRange loop exists to catch a + // decoded-then-mutated offer with an out-of-range TLV + // resurfacing via decodedTLVs. + name: "out-of-range TLV in decoded extras", + mutate: func(o *Offer) { + o.decodedTLVs = tlv.TypeMap{200: nil} + }, + wantErr: ErrOutOfRangeType, + }, + { + name: "empty blinded paths list", + mutate: func(o *Offer) { + o.OfferPaths = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType16]( + lnwire.BlindedPaths{Paths: nil}, + ), + ) + }, + wantErr: ErrEmptyBlindedPaths, + }, + { + name: "blinded path with zero hops", + mutate: func(o *Offer) { + _, intro := aliceKey() + _, blinding := bobKey() + pk := lnwire.PubkeyIntro{Pubkey: intro} + o.OfferPaths = tlv.SomeRecordT( + //nolint:ll + tlv.NewRecordT[tlv.TlvType16]( + lnwire.BlindedPaths{ + Paths: []lnwire.BlindedPath{{ + IntroductionNode: pk, + BlindingPoint: blinding, + Hops: nil, + }}, + }, + ), + ) + }, + wantErr: lnwire.ErrEmptyBlindedPath, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + o := validBobOffer(t) + tc.mutate(o) + + err := ValidateOfferWrite(o) + if tc.wantErr == nil { + require.NoError(t, err) + return + } + require.ErrorIs(t, err, tc.wantErr) + }) + } +} + +// TestValidateOfferRead pins the BOLT 12 reader-side MUSTs so a malformed or +// unsafe offer is rejected before any invoice request reaches the wire. +func TestValidateOfferRead(t *testing.T) { + t.Parallel() + + now := time.Unix(1_700_000_000, 0) + + var nonBitcoin [32]byte + nonBitcoin[0] = 0x01 + + tests := []struct { + name string + mutate func(*Offer) + activeChain [32]byte + wantErr error + }{ + { + name: "happy path on bitcoin mainnet", + mutate: func(*Offer) {}, + activeChain: bitcoinMainnetGenesisHash, + }, + { + name: "out-of-range TLV in decoded extras", + mutate: func(o *Offer) { + o.decodedTLVs = tlv.TypeMap{200: nil} + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrOutOfRangeType, + }, + { + name: "unknown even TLV type in range rejected", + mutate: func(o *Offer) { + o.decodedTLVs = tlv.TypeMap{24: nil} + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrUnknownEvenType, + }, + { + name: "unknown even feature bit rejected", + mutate: func(o *Offer) { + o.OfferFeatures = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType12]( + *lnwire.NewRawFeatureVector(0), + ), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrUnknownEvenFeature, + }, + { + name: "unknown odd feature bit ignored", + mutate: func(o *Offer) { + o.OfferFeatures = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType12]( + *lnwire.NewRawFeatureVector(1), + ), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: nil, + }, + { + name: "non-bitcoin chain rejected when " + + "offer_chains absent", + mutate: func(*Offer) {}, + activeChain: nonBitcoin, + wantErr: ErrUnsupportedChain, + }, + { + name: "explicit chain list missing active chain", + mutate: func(o *Offer) { + var c [32]byte + c[0] = 0xaa + o.OfferChains = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType2]( + ChainsRecord{ + Chains: [][32]byte{c}, + }, + ), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrUnsupportedChain, + }, + { + name: "empty offer_chains list", + mutate: func(o *Offer) { + o.OfferChains = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType2]( + ChainsRecord{Chains: nil}, + ), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrEmptyChains, + }, + { + name: "amount without description", + mutate: func(o *Offer) { + o.OfferAmount = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType8]( + TUint64(1000), + ), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrMissingDescription, + }, + { + name: "currency without amount", + mutate: func(o *Offer) { + o.OfferCurrency = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType6, + tlv.Blob](tlv.Blob("USD")), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrCurrencyWithoutAmount, + }, + { + name: "zero amount with description", + mutate: func(o *Offer) { + o.OfferAmount = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType8]( + TUint64(0), + ), + ) + o.OfferDescription = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType10, + tlv.Blob](tlv.Blob("a tip")), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrZeroAmount, + }, + { + name: "missing issuer and paths", + mutate: func(o *Offer) { + o.OfferIssuerID = tlv.OptionalRecordT[ + tlv.TlvType22, *btcec.PublicKey]{} + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrNoIssuerIdentity, + }, + { + name: "blinded path with zero hops", + mutate: func(o *Offer) { + _, intro := aliceKey() + _, blinding := bobKey() + pk := lnwire.PubkeyIntro{Pubkey: intro} + o.OfferPaths = tlv.SomeRecordT( + //nolint:ll + tlv.NewRecordT[tlv.TlvType16]( + lnwire.BlindedPaths{ + Paths: []lnwire.BlindedPath{{ + IntroductionNode: pk, + BlindingPoint: blinding, + Hops: nil, + }}, + }, + ), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: lnwire.ErrEmptyBlindedPath, + }, + { + name: "expired offer", + mutate: func(o *Offer) { + expiry := uint64(now.Unix()) - 1 + o.OfferAbsoluteExpiry = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType14]( + TUint64(expiry), + ), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrOfferExpired, + }, + { + name: "currency wrong length", + mutate: func(o *Offer) { + addAmountAndDescription(o) + o.OfferCurrency = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType6, + tlv.Blob](tlv.Blob("US")), + ) + }, + activeChain: bitcoinMainnetGenesisHash, + wantErr: ErrInvalidCurrency, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + o := validBobOffer(t) + tc.mutate(o) + + err := ValidateOfferRead(o, now, tc.activeChain) + if tc.wantErr == nil { + require.NoError(t, err) + return + } + require.ErrorIs(t, err, tc.wantErr) + }) + } +} + +// addAmountAndDescription satisfies the dependency rules so currency-shape rows +// are not short-circuited before the ISO 4217 check runs. +func addAmountAndDescription(o *Offer) { + o.OfferAmount = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType8](TUint64(1000)), + ) + o.OfferDescription = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType10, tlv.Blob]( + tlv.Blob("a tip"), + ), + ) +} diff --git a/go.mod b/go.mod index af03897854d..99309e397cf 100644 --- a/go.mod +++ b/go.mod @@ -178,7 +178,7 @@ require ( golang.org/x/mod v0.30.0 // indirect golang.org/x/net v0.48.0 // indirect golang.org/x/sys v0.39.0 // indirect - golang.org/x/text v0.32.0 // indirect + golang.org/x/text v0.32.0 golang.org/x/tools v0.39.0 // indirect google.golang.org/genproto v0.0.0-20231016165738-49dd2c1f3d0b // indirect google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect From ee2aa98a8f29421ef70a5cdad54240069f88a903 Mon Sep 17 00:00:00 2001 From: bitromortac Date: Mon, 18 May 2026 11:40:56 +0200 Subject: [PATCH 08/11] docs: update release notes --- bolt12/validate.go | 1 + docs/release-notes/release-notes-0.22.0.md | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/bolt12/validate.go b/bolt12/validate.go index 2ed93196e22..e72535ad8ee 100644 --- a/bolt12/validate.go +++ b/bolt12/validate.go @@ -340,6 +340,7 @@ func checkFeatures[T tlv.TlvType]( // Sort for deterministic errors. slices.Sort(unknown) + return fmt.Errorf("%w: bit %d", ErrUnknownEvenFeature, unknown[0]) }, diff --git a/docs/release-notes/release-notes-0.22.0.md b/docs/release-notes/release-notes-0.22.0.md index fb31bad5662..454883371c6 100644 --- a/docs/release-notes/release-notes-0.22.0.md +++ b/docs/release-notes/release-notes-0.22.0.md @@ -59,6 +59,11 @@ later in the reservation flow as a funder-balance-dust error; they now surface a clearer, spec-aligned error string up front. +* [Initial BOLT 12 Offer codec](https://github.com/lightningnetwork/lnd/pull/10789): + add a new `bolt12/` package with the BOLT 12 `offer` TLV codec and full + reader/writer validation, plus a typed `lnwire.BlindedPath` introduction-node + codec shared by HTLC routing and onion messaging. + ## Testing ## Database @@ -69,5 +74,6 @@ # Contributors (Alphabetical Order) +* bitromortac * Boris Nagaev * Erick Cestari From dc5ebeea746c80fef2b10c629799c846e7f527c7 Mon Sep 17 00:00:00 2001 From: bitromortac Date: Fri, 22 May 2026 09:07:10 +0200 Subject: [PATCH 09/11] lnrpc: document reply_path verbatim passthrough on OnionMessageUpdate Document that the introduction_node field in an OnionMessageUpdate's reply_path is passed through verbatim from the wire, potentially carrying either a 33-byte pubkey or a 9-byte sciddir form. Subscribers wishing to reply must resolve sciddir forms against their local channel graph. The SubscribeOnionMessages bridge is refactored to use a new marshallBlindedPath helper, ensuring a nil reply path remains nil in the RPC response rather than being emitted as an empty struct. --- lnrpc/lightning.pb.go | 7 +++++- lnrpc/lightning.proto | 11 ++++++-- lnrpc/lightning.swagger.json | 2 +- rpcserver.go | 49 ++++++++++++++++++++++-------------- 4 files changed, 46 insertions(+), 23 deletions(-) diff --git a/lnrpc/lightning.pb.go b/lnrpc/lightning.pb.go index 59d01914159..eff66a4cbed 100644 --- a/lnrpc/lightning.pb.go +++ b/lnrpc/lightning.pb.go @@ -1819,7 +1819,12 @@ type OnionMessageUpdate struct { // along its designated path. Onion []byte `protobuf:"bytes,3,opt,name=onion,proto3" json:"onion,omitempty"` // reply_path is the blinded path that should be used when replying to a - // received message. + // received message. The introduction_node field is passed through verbatim + // from the wire. It may carry either the 33-byte SEC1 compressed pubkey + // form or the 9-byte sciddir form. The sciddir form consists of a 1-byte + // direction selector (0x00 or 0x01) followed by an 8-byte short channel ID. + // Subscribers that intend to reply resolve the sciddir form against their + // local channel graph. ReplyPath *BlindedPath `protobuf:"bytes,4,opt,name=reply_path,json=replyPath,proto3" json:"reply_path,omitempty"` // encrypted_recipient_data is the encrypted data that contains the // forwarding information for an onion message. It contains either diff --git a/lnrpc/lightning.proto b/lnrpc/lightning.proto index b60f45e0950..20114f45c25 100644 --- a/lnrpc/lightning.proto +++ b/lnrpc/lightning.proto @@ -695,8 +695,15 @@ message OnionMessageUpdate { // along its designated path. bytes onion = 3; - // reply_path is the blinded path that should be used when replying to a - // received message. + /* + reply_path is the blinded path that should be used when replying to a + received message. The introduction_node field is passed through verbatim + from the wire. It may carry either the 33-byte SEC1 compressed pubkey + form or the 9-byte sciddir form. The sciddir form consists of a 1-byte + direction selector (0x00 or 0x01) followed by an 8-byte short channel ID. + Subscribers that intend to reply resolve the sciddir form against their + local channel graph. + */ BlindedPath reply_path = 4; // encrypted_recipient_data is the encrypted data that contains the diff --git a/lnrpc/lightning.swagger.json b/lnrpc/lightning.swagger.json index 0818132ede3..5eba2489104 100644 --- a/lnrpc/lightning.swagger.json +++ b/lnrpc/lightning.swagger.json @@ -6627,7 +6627,7 @@ }, "reply_path": { "$ref": "#/definitions/lnrpcBlindedPath", - "description": "reply_path is the blinded path that should be used when replying to a\nreceived message." + "description": "reply_path is the blinded path that should be used when replying to a\nreceived message. The introduction_node field is passed through verbatim\nfrom the wire. It may carry either the 33-byte SEC1 compressed pubkey\nform or the 9-byte sciddir form. The sciddir form consists of a 1-byte\ndirection selector (0x00 or 0x01) followed by an 8-byte short channel ID.\nSubscribers that intend to reply resolve the sciddir form against their\nlocal channel graph." }, "encrypted_recipient_data": { "type": "string", diff --git a/rpcserver.go b/rpcserver.go index 6cbe86291b5..eddbabe603e 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -9566,25 +9566,8 @@ func (r *rpcServer) SubscribeOnionMessages( "failed type assertion: %T", update) } - bp := &lnrpc.BlindedPath{} - - //nolint:ll - if oMsg.ReplyPath != nil { - // TODO(bolt12): resolve sciddir intros via - // sciddirResolver so this field is uniformly a - // 33-byte pubkey? - bp.IntroductionNode = oMsg.ReplyPath.IntroductionNode.Bytes() - bp.BlindingPoint = oMsg.ReplyPath.BlindingPoint.SerializeCompressed() - - for _, hop := range oMsg.ReplyPath.Hops { - bp.BlindedHops = append( - bp.BlindedHops, &lnrpc.BlindedHop{ - BlindedNode: hop.BlindedNodeID.SerializeCompressed(), - EncryptedData: hop.EncryptedData, - }, - ) - } - } + // Perform a verbatim pass-through of any reply path. + bp := marshallBlindedPath(oMsg.ReplyPath) //nolint:ll err := server.Send(&lnrpc.OnionMessageUpdate{ @@ -9602,6 +9585,34 @@ func (r *rpcServer) SubscribeOnionMessages( } } +// marshallBlindedPath converts a wire-form blinded path into its RPC +// counterpart. If the input is nil, nil is returned. +func marshallBlindedPath(p *lnwire.BlindedPath) *lnrpc.BlindedPath { + if p == nil { + return nil + } + + bp := &lnrpc.BlindedPath{ + // Introduction node may be a short channel id direction. We + // don't convert to a node public key here to safe us from + // making a db query. + IntroductionNode: p.IntroductionNode.Bytes(), + BlindingPoint: p.BlindingPoint.SerializeCompressed(), + } + + for _, hop := range p.Hops { + blindedNode := hop.BlindedNodeID.SerializeCompressed() + bp.BlindedHops = append( + bp.BlindedHops, &lnrpc.BlindedHop{ + BlindedNode: blindedNode, + EncryptedData: hop.EncryptedData, + }, + ) + } + + return bp +} + // ListAliases returns the set of all aliases we have ever allocated along with // their base SCIDs and possibly a separate confirmed SCID in the case of // zero-conf. From 948e0c1006c17add0c4304c6c31a8a8967244f7a Mon Sep 17 00:00:00 2001 From: bitromortac Date: Fri, 22 May 2026 09:07:10 +0200 Subject: [PATCH 10/11] bolt12: add InvoiceRequest message struct with TLV codec The InvoiceRequest struct represents a BOLT 12 payment request as defined in the offers specification. It mirrors relevant fields from the original offer and adds payer-specific fields (metadata, amount, quantity, payer_id, note, reply paths, BIP 353 name) and a Schnorr signature. The struct implements lnwire.PureTLVMessage and supports round-trip serialization via Encode/DecodeInvoiceRequest. getInvoiceRequestOfferChains and checkInvreqQuantity provide spec-mandated defaults and coupling checks between the offer and request layers. --- bolt12/invoice_request.go | 277 +++++++++++++++++++++++++++++++++ bolt12/invoice_request_test.go | 56 +++++++ bolt12/validate.go | 10 ++ 3 files changed, 343 insertions(+) create mode 100644 bolt12/invoice_request.go create mode 100644 bolt12/invoice_request_test.go diff --git a/bolt12/invoice_request.go b/bolt12/invoice_request.go new file mode 100644 index 00000000000..412e9ae28aa --- /dev/null +++ b/bolt12/invoice_request.go @@ -0,0 +1,277 @@ +package bolt12 + +import ( + "bytes" + "fmt" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +// InvoiceRequest represents a BOLT 12 invoice_request message. It mirrors +// offer fields from the original offer. It also adds payer-specific +// fields and a Schnorr signature. +type InvoiceRequest struct { + // OfferChains are the chains that the mirrored offer is valid for. + OfferChains tlv.OptionalRecordT[tlv.TlvType2, ChainsRecord] + + // OfferMetadata is the metadata from the mirrored offer. + OfferMetadata tlv.OptionalRecordT[tlv.TlvType4, tlv.Blob] + + // OfferCurrency is the currency from the mirrored offer. + OfferCurrency tlv.OptionalRecordT[tlv.TlvType6, tlv.Blob] + + // OfferAmount is the amount from the mirrored offer. + OfferAmount tlv.OptionalRecordT[tlv.TlvType8, TUint64] + + // OfferDescription is the description from the mirrored offer. + OfferDescription tlv.OptionalRecordT[tlv.TlvType10, tlv.Blob] + + // OfferFeatures are the features required by the mirrored offer. + OfferFeatures tlv.OptionalRecordT[ + tlv.TlvType12, lnwire.RawFeatureVector, + ] + + // OfferAbsoluteExpiry is the absolute expiry from the mirrored offer. + OfferAbsoluteExpiry tlv.OptionalRecordT[tlv.TlvType14, TUint64] + + // OfferPaths are the blinded paths from the mirrored offer. + OfferPaths tlv.OptionalRecordT[tlv.TlvType16, lnwire.BlindedPaths] + + // OfferIssuer is the issuer name from the mirrored offer. + OfferIssuer tlv.OptionalRecordT[tlv.TlvType18, tlv.Blob] + + // OfferQuantityMax is the maximum quantity allowed by the mirrored + // offer. + OfferQuantityMax tlv.OptionalRecordT[tlv.TlvType20, TUint64] + + // OfferIssuerID is the public key of the offer issuer. + OfferIssuerID tlv.OptionalRecordT[tlv.TlvType22, *btcec.PublicKey] + + // InvreqMetadata is a blob of metadata provided by the payer. + InvreqMetadata tlv.OptionalRecordT[tlv.TlvType0, tlv.Blob] + + // InvreqChain is the chain that the payer is using for this request. + InvreqChain tlv.OptionalRecordT[tlv.TlvType80, [32]byte] + + // InvreqAmount is the amount the payer is offering to pay. + InvreqAmount tlv.OptionalRecordT[tlv.TlvType82, TUint64] + + // InvreqFeatures are the features provided by the payer. + InvreqFeatures tlv.OptionalRecordT[ + tlv.TlvType84, lnwire.RawFeatureVector, + ] + + // InvreqQuantity is the quantity of the offer item being requested. + InvreqQuantity tlv.OptionalRecordT[tlv.TlvType86, TUint64] + + // InvreqPayerID is the public key used by the payer to sign the + // request. + InvreqPayerID tlv.OptionalRecordT[tlv.TlvType88, [33]byte] + + // InvreqPayerNote is an optional note from the payer. + InvreqPayerNote tlv.OptionalRecordT[tlv.TlvType89, tlv.Blob] + + // InvreqPaths are the blinded paths the payer wants the invoice to be + // sent to. + InvreqPaths tlv.OptionalRecordT[tlv.TlvType90, lnwire.BlindedPaths] + + // InvreqBip353Name is the BIP 353 name of the payer. + InvreqBip353Name tlv.OptionalRecordT[tlv.TlvType91, tlv.Blob] + + // Signature is a BIP-340 Schnorr signature covering all fields. + Signature tlv.OptionalRecordT[tlv.TlvType240, [64]byte] + + // decodedTLVs is the canonical TypeMap produced by the typed- + // stream pass that decoded this request. See Offer.decodedTLVs + // for the design rationale. + decodedTLVs tlv.TypeMap +} + +// AllRecords returns the canonical sorted record list for this invoice +// request, merging the typed records with any extra signed-range fields +// that the decoder preserved. +// +// NOTE: this is part of the tlv.PureTLVMessage interface. +func (ir *InvoiceRequest) AllRecords() []tlv.Record { + return allRecordsFromTypeMap( + ir.allRecordProducers(), ir.decodedTLVs, + ) +} + +var _ lnwire.PureTLVMessage = (*InvoiceRequest)(nil) + +// allRecordProducers returns the set of records that are present. +func (ir *InvoiceRequest) allRecordProducers() []tlv.RecordProducer { + var p []tlv.RecordProducer + + lnwire.AddOpt(&p, ir.InvreqMetadata) + lnwire.AddOpt(&p, ir.OfferChains) + lnwire.AddOpt(&p, ir.OfferMetadata) + lnwire.AddOpt(&p, ir.OfferCurrency) + lnwire.AddOpt(&p, ir.OfferAmount) + lnwire.AddOpt(&p, ir.OfferDescription) + lnwire.AddOpt(&p, ir.OfferFeatures) + lnwire.AddOpt(&p, ir.OfferAbsoluteExpiry) + lnwire.AddOpt(&p, ir.OfferPaths) + lnwire.AddOpt(&p, ir.OfferIssuer) + lnwire.AddOpt(&p, ir.OfferQuantityMax) + lnwire.AddOpt(&p, ir.OfferIssuerID) + lnwire.AddOpt(&p, ir.InvreqChain) + lnwire.AddOpt(&p, ir.InvreqAmount) + lnwire.AddOpt(&p, ir.InvreqFeatures) + lnwire.AddOpt(&p, ir.InvreqQuantity) + lnwire.AddOpt(&p, ir.InvreqPayerID) + lnwire.AddOpt(&p, ir.InvreqPayerNote) + lnwire.AddOpt(&p, ir.InvreqPaths) + lnwire.AddOpt(&p, ir.InvreqBip353Name) + lnwire.AddOpt(&p, ir.Signature) + + return p +} + +// Encode serialises the invoice request via the PureTLVMessage shape. +// The per-record canonicalisation is pure: a struct mutated and +// re-encoded reflects the new bytes without any sidecar rehydration +// step. +func (ir *InvoiceRequest) Encode() ([]byte, error) { + var buf bytes.Buffer + if err := lnwire.EncodePureTLVMessage(ir, &buf); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// DecodeInvoiceRequest deserializes an invoice request from a TLV byte +// stream. Decoding is permissive: callers that need spec compliance must run +// ValidateInvoiceRequestRead. +func DecodeInvoiceRequest(data []byte) (*InvoiceRequest, error) { + var ir InvoiceRequest + + invreqMetadata := tlv.ZeroRecordT[tlv.TlvType0, tlv.Blob]() + chains := tlv.ZeroRecordT[tlv.TlvType2, ChainsRecord]() + metadata := tlv.ZeroRecordT[tlv.TlvType4, tlv.Blob]() + currency := tlv.ZeroRecordT[tlv.TlvType6, tlv.Blob]() + amount := tlv.ZeroRecordT[tlv.TlvType8, TUint64]() + desc := tlv.ZeroRecordT[tlv.TlvType10, tlv.Blob]() + features := tlv.ZeroRecordT[tlv.TlvType12, lnwire.RawFeatureVector]() + expiry := tlv.ZeroRecordT[tlv.TlvType14, TUint64]() + paths := tlv.ZeroRecordT[tlv.TlvType16, lnwire.BlindedPaths]() + issuer := tlv.ZeroRecordT[tlv.TlvType18, tlv.Blob]() + qtyMax := tlv.ZeroRecordT[tlv.TlvType20, TUint64]() + issuerID := tlv.ZeroRecordT[tlv.TlvType22, *btcec.PublicKey]() + invreqChain := tlv.ZeroRecordT[tlv.TlvType80, [32]byte]() + invreqAmount := tlv.ZeroRecordT[tlv.TlvType82, TUint64]() + invreqFeatures := tlv.ZeroRecordT[ + tlv.TlvType84, lnwire.RawFeatureVector, + ]() + invreqQty := tlv.ZeroRecordT[tlv.TlvType86, TUint64]() + payerID := tlv.ZeroRecordT[tlv.TlvType88, [33]byte]() + payerNote := tlv.ZeroRecordT[tlv.TlvType89, tlv.Blob]() + invreqPaths := tlv.ZeroRecordT[tlv.TlvType90, lnwire.BlindedPaths]() + bip353 := tlv.ZeroRecordT[tlv.TlvType91, tlv.Blob]() + sig := tlv.ZeroRecordT[tlv.TlvType240, [64]byte]() + + tm, err := decodeStream( + data, + invreqMetadata.Record(), + chains.Record(), + metadata.Record(), + currency.Record(), + amount.Record(), + desc.Record(), + features.Record(), + expiry.Record(), + paths.Record(), + issuer.Record(), + qtyMax.Record(), + issuerID.Record(), + invreqChain.Record(), + invreqAmount.Record(), + invreqFeatures.Record(), + invreqQty.Record(), + payerID.Record(), + payerNote.Record(), + invreqPaths.Record(), + bip353.Record(), + sig.Record(), + ) + if err != nil { + return nil, fmt.Errorf("decode invoice request: %w", err) + } + + lnwire.SetOptFromMap(tm, &ir.InvreqMetadata, invreqMetadata) + lnwire.SetOptFromMap(tm, &ir.OfferChains, chains) + lnwire.SetOptFromMap(tm, &ir.OfferMetadata, metadata) + lnwire.SetOptFromMap(tm, &ir.OfferCurrency, currency) + lnwire.SetOptFromMap(tm, &ir.OfferAmount, amount) + lnwire.SetOptFromMap(tm, &ir.OfferDescription, desc) + lnwire.SetOptFromMap(tm, &ir.OfferFeatures, features) + lnwire.SetOptFromMap(tm, &ir.OfferAbsoluteExpiry, expiry) + lnwire.SetOptFromMap(tm, &ir.OfferPaths, paths) + lnwire.SetOptFromMap(tm, &ir.OfferIssuer, issuer) + lnwire.SetOptFromMap(tm, &ir.OfferQuantityMax, qtyMax) + lnwire.SetOptFromMap(tm, &ir.OfferIssuerID, issuerID) + lnwire.SetOptFromMap(tm, &ir.InvreqChain, invreqChain) + lnwire.SetOptFromMap(tm, &ir.InvreqAmount, invreqAmount) + lnwire.SetOptFromMap(tm, &ir.InvreqFeatures, invreqFeatures) + lnwire.SetOptFromMap(tm, &ir.InvreqQuantity, invreqQty) + lnwire.SetOptFromMap(tm, &ir.InvreqPayerID, payerID) + lnwire.SetOptFromMap(tm, &ir.InvreqPayerNote, payerNote) + lnwire.SetOptFromMap(tm, &ir.InvreqPaths, invreqPaths) + lnwire.SetOptFromMap(tm, &ir.InvreqBip353Name, bip353) + lnwire.SetOptFromMap(tm, &ir.Signature, sig) + + ir.decodedTLVs = tm + + return &ir, nil +} + +// getInvoiceRequestOfferChains returns the chains an invoice request's mirrored +// offer is valid for. If offer_chains is absent, the spec defaults to Bitcoin +// mainnet. +func getInvoiceRequestOfferChains(ir *InvoiceRequest) [][32]byte { + chains := fn.MapOptionZ( + ir.OfferChains.ValOpt(), + func(r ChainsRecord) [][32]byte { return r.Chains }, + ) + + if len(chains) == 0 { + chains = [][32]byte{bitcoinMainnetGenesisHash} + } + + return chains +} + +// checkInvreqQuantity validates the spec coupling between offer_quantity_max +// and invreq_quantity. +func checkInvreqQuantity(ir *InvoiceRequest) error { + if !ir.OfferQuantityMax.IsSome() { + return nil + } + + var qty uint64 + ir.InvreqQuantity.WhenSome( + func(r tlv.RecordT[tlv.TlvType86, TUint64]) { + qty = uint64(r.Val) + }, + ) + if qty == 0 { + return ErrQuantityZero + } + + var maxQty uint64 + ir.OfferQuantityMax.WhenSome( + func(r tlv.RecordT[tlv.TlvType20, TUint64]) { + maxQty = uint64(r.Val) + }, + ) + if maxQty > 0 && qty > maxQty { + return ErrQuantityExceedsMax + } + + return nil +} diff --git a/bolt12/invoice_request_test.go b/bolt12/invoice_request_test.go new file mode 100644 index 00000000000..50f1f6f9702 --- /dev/null +++ b/bolt12/invoice_request_test.go @@ -0,0 +1,56 @@ +package bolt12 + +import ( + "testing" + + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// TestInvoiceRequestRoundTrip pins encode→decode→re-encode for an +// InvoiceRequest with a representative subset of optional fields. A +// byte-identical re-encode is the invariant that keeps the to-be-signed +// payload stable across the codec boundary. +func TestInvoiceRequestRoundTrip(t *testing.T) { + t.Parallel() + + _, bobPub := bobKey() + var payerID [33]byte + copy(payerID[:], bobPub.SerializeCompressed()) + + metadata := tlv.Blob("payer-metadata") + + ir := &InvoiceRequest{ + InvreqPayerID: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType88, [33]byte]( + payerID, + ), + ), + InvreqMetadata: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType0](metadata), + ), + InvreqAmount: tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType82, TUint64](1000), + ), + Signature: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType240]([64]byte{0x01}), + ), + } + + encoded, err := ir.Encode() + require.NoError(t, err) + require.NotEmpty(t, encoded) + + decoded, err := DecodeInvoiceRequest(encoded) + require.NoError(t, err) + + require.Equal(t, payerID, decoded.InvreqPayerID.UnwrapOrFailV(t)) + require.Equal(t, metadata, decoded.InvreqMetadata.UnwrapOrFailV(t)) + require.Equal( + t, TUint64(1000), decoded.InvreqAmount.UnwrapOrFailV(t), + ) + + reencoded, err := decoded.Encode() + require.NoError(t, err) + require.Equal(t, encoded, reencoded) +} diff --git a/bolt12/validate.go b/bolt12/validate.go index e72535ad8ee..3a967beef97 100644 --- a/bolt12/validate.go +++ b/bolt12/validate.go @@ -82,6 +82,16 @@ var ( // ErrInvalidCurrency is returned when offer_currency is not a valid ISO // 4217 code. ErrInvalidCurrency = errors.New("invalid offer_currency") + + // ErrQuantityZero is returned when invreq_quantity is 0 but + // offer_quantity_max is present. + ErrQuantityZero = errors.New("invreq_quantity is zero") + + // ErrQuantityExceedsMax is returned when invreq_quantity is greater + // than offer_quantity_max. + ErrQuantityExceedsMax = errors.New( + "invreq_quantity exceeds offer_quantity_max", + ) ) // offerAllowedRange returns true if the TLV type falls within the allowed From 38e5fc9256945d66f50a9bba8d53015187e9c19f Mon Sep 17 00:00:00 2001 From: bitromortac Date: Fri, 22 May 2026 09:07:10 +0200 Subject: [PATCH 11/11] bolt12: validate InvoiceRequest per BOLT 12 reader/writer requirements ValidateInvoiceRequestRead and ValidateInvoiceRequestWrite enforce the structural BOLT 12 requirements an invoice request can be checked against on its own. The reader validates incoming requests. The writer catches out-of-range types in decoded-then-mutated requests before they leave the local boundary. Type 240 carries the signature and sits outside the allowed range by spec design. Both validators skip it during the range scan. Two reader MUSTs are deferred. Schnorr signature verification against the merkle root keyed by invreq_payer_id lands with the Invoice message, where the merkle and signing primitives are shared. Offer cross-validation requires an Offer reference the structural validator does not carry, and lands in the bolt12handler layer where both the request and the stored Offer are in scope. --- bolt12/invoice_request.go | 12 +- bolt12/validate.go | 263 +++++++++++++++++++++++++++ bolt12/validate_test.go | 371 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 642 insertions(+), 4 deletions(-) diff --git a/bolt12/invoice_request.go b/bolt12/invoice_request.go index 412e9ae28aa..edf0affa092 100644 --- a/bolt12/invoice_request.go +++ b/bolt12/invoice_request.go @@ -132,11 +132,15 @@ func (ir *InvoiceRequest) allRecordProducers() []tlv.RecordProducer { return p } -// Encode serialises the invoice request via the PureTLVMessage shape. -// The per-record canonicalisation is pure: a struct mutated and -// re-encoded reflects the new bytes without any sidecar rehydration -// step. +// Encode validates the invoice request per writer requirements and +// serialises it via the PureTLVMessage shape. The per-record +// canonicalisation is pure: a struct mutated and re-encoded reflects +// the new bytes without any sidecar rehydration step. func (ir *InvoiceRequest) Encode() ([]byte, error) { + if err := ValidateInvoiceRequestWrite(ir); err != nil { + return nil, fmt.Errorf("validate invoice request: %w", err) + } + var buf bytes.Buffer if err := lnwire.EncodePureTLVMessage(ir, &buf); err != nil { return nil, err diff --git a/bolt12/validate.go b/bolt12/validate.go index 3a967beef97..59d233a2406 100644 --- a/bolt12/validate.go +++ b/bolt12/validate.go @@ -83,6 +83,24 @@ var ( // 4217 code. ErrInvalidCurrency = errors.New("invalid offer_currency") + // ErrMissingPayerID is returned when invreq_payer_id is absent. + ErrMissingPayerID = errors.New("missing invreq_payer_id") + + // ErrMissingMetadata is returned when invreq_metadata is absent. + ErrMissingMetadata = errors.New("missing invreq_metadata") + + // ErrMissingAmount is returned when invoice_amount or invreq_amount is + // absent. + ErrMissingAmount = errors.New("missing amount field") + + // ErrZeroInvoiceAmount is returned by the invoice validators when + // invoice_amount is present but zero. The BOLT 12 spec permits a zero + // amount ("the minimum amount it will accept"). LND rejects it as a + // project policy because a zero-amount HTLC cannot settle past the + // channel-layer dust limit. Callers that want spec-compatible + // behaviour can recover this typed error and proceed manually. + ErrZeroInvoiceAmount = errors.New("invoice_amount is zero") + // ErrQuantityZero is returned when invreq_quantity is 0 but // offer_quantity_max is present. ErrQuantityZero = errors.New("invreq_quantity is zero") @@ -92,9 +110,188 @@ var ( ErrQuantityExceedsMax = errors.New( "invreq_quantity exceeds offer_quantity_max", ) + + // ErrInvalidBip353Name is returned when invreq_bip_353_name is + // structurally malformed or contains a non-alphabet byte. + ErrInvalidBip353Name = errors.New("invalid invreq_bip_353_name") + + // ErrMissingSignature is returned when a wire-form invoice or + // invoice_request is emitted without a populated signature TLV. + // Pre-sign Encode (used to compute the Merkle root) is permitted to run + // without a signature; the bech32 string-codec layer is where the + // signature becomes mandatory. + ErrMissingSignature = errors.New("missing signature") ) +// isKnownInvreqTLVType determines if a TLV type is defined in the +// invoice_request specification. +func isKnownInvreqTLVType(typ tlv.Type) bool { + if isKnownOfferTLVType(typ) { + return true + } + switch typ { + case 0, 80, 82, 84, 86, 88, 89, 90, 91, 240: + return true + default: + return false + } +} + +// ValidateInvoiceRequestWrite ensures an invoice request adheres to the BOLT 12 +// writer requirements. +func ValidateInvoiceRequestWrite(ir *InvoiceRequest) error { + for _, t := range sortedTypes(ir.decodedTLVs) { + if t == 240 { + continue + } + if !invreqAllowedRange(t) { + return fmt.Errorf("%w: type %d", + ErrOutOfRangeType, t) + } + } + + if !ir.InvreqPayerID.IsSome() { + return ErrMissingPayerID + } + if !ir.InvreqMetadata.IsSome() { + return ErrMissingMetadata + } + + // If offer_amount is not present, we MUST set invreq_amount. + if !ir.OfferAmount.IsSome() && !ir.InvreqAmount.IsSome() { + return ErrMissingAmount + } + + if err := checkInvreqQuantity(ir); err != nil { + return err + } + + if err := checkBip353Name(ir.InvreqBip353Name); err != nil { + return err + } + + if err := checkFeatures(ir.InvreqFeatures); err != nil { + return err + } + + err := checkUTF8(ir.InvreqPayerNote, "invreq_payer_note") + if err != nil { + return err + } + + if err := checkBlindedPaths(ir.InvreqPaths); err != nil { + return err + } + + return nil +} + +// invreqAllowedRange determines if the TLV type falls within the allowed +// ranges for invoice request messages. +func invreqAllowedRange(typ tlv.Type) bool { + return typ <= 159 || + (typ >= 1000000000 && typ <= 2999999999) +} + +// ValidateInvoiceRequestRead ensures an invoice request adheres to the BOLT 12 +// reader requirements. +func ValidateInvoiceRequestRead(now time.Time, activeChain [32]byte, + ir *InvoiceRequest) error { + + // Check TLV types are in allowed range and that unknown even types are + // rejected (even = must-understand). Type 240 carries the signature and + // sits outside the allowed range by spec design; skip it here as + // ValidateInvoiceRequestWrite does (validate.go:144). + for _, t := range sortedTypes(ir.decodedTLVs) { + if t == 240 { + continue + } + if !invreqAllowedRange(t) { + return fmt.Errorf("%w: type %d", ErrOutOfRangeType, t) + } + + if !isKnownInvreqTLVType(t) && t%2 == 0 { + return fmt.Errorf("%w: type %d", ErrUnknownEvenType, t) + } + } + + // Check for unknown even feature bits. + if err := checkFeatures(ir.InvreqFeatures); err != nil { + return err + } + + if !ir.InvreqPayerID.IsSome() { + return ErrMissingPayerID + } + + if !ir.InvreqMetadata.IsSome() { + return ErrMissingMetadata + } + + // If offer_amount is not present, we MUST set invreq_amount. + if !ir.OfferAmount.IsSome() && !ir.InvreqAmount.IsSome() { + return ErrMissingAmount + } + + if err := checkInvreqQuantity(ir); err != nil { + return err + } + + if err := checkBip353Name(ir.InvreqBip353Name); err != nil { + return err + } + + // Validate the offer's chain against the active chain. An absent + // offer_chains TLV means "Bitcoin mainnet" per spec. + offerChains := getInvoiceRequestOfferChains(ir) + if !slices.Contains(offerChains, activeChain) { + return ErrUnsupportedChain + } + + // If invreq_chain is present, it must match the active chain. + var chainMismatch bool + ir.InvreqChain.WhenSome(func(r tlv.RecordT[tlv.TlvType80, [32]byte]) { + if r.Val != activeChain { + chainMismatch = true + } + }) + if chainMismatch { + return ErrUnsupportedChain + } + + // Expiry check. + var ( + expiry uint64 + hasExpiry bool + ) + ir.OfferAbsoluteExpiry.WhenSome( + func(r tlv.RecordT[tlv.TlvType14, TUint64]) { + expiry = uint64(r.Val) + hasExpiry = true + }, + ) + if hasExpiry && uint64(now.Unix()) > expiry { + return ErrOfferExpired + } + + err := checkUTF8(ir.InvreqPayerNote, "invreq_payer_note") + if err != nil { + return err + } + + if err := checkBlindedPaths(ir.InvreqPaths); err != nil { + return err + } + + if !ir.Signature.IsSome() { + return ErrMissingSignature + } + + return nil +} + // offerAllowedRange returns true if the TLV type falls within the allowed + // ranges for offer messages: 1-79 and 1000000000-1999999999. func offerAllowedRange(typ tlv.Type) bool { return (typ >= 1 && typ <= 79) || @@ -406,3 +603,69 @@ func checkUTF8[T tlv.TlvType](opt tlv.OptionalRecordT[T, tlv.Blob], return nil }) } + +// checkBip353Name validates the wire layout and alphabet of +// invreq_bip_353_name. Both name and domain MUST contain only DNS-safe +// characters per the BOLT 12 reader and writer requirements. +func checkBip353Name( + opt tlv.OptionalRecordT[tlv.TlvType91, tlv.Blob]) error { + + var data []byte + opt.WhenSome(func(r tlv.RecordT[tlv.TlvType91, tlv.Blob]) { + data = r.Val + }) + if data == nil { + return nil + } + + if len(data) < 1 { + return fmt.Errorf("%w: missing name_len", + ErrInvalidBip353Name) + } + nameLen := int(data[0]) + + domainLenIdx := 1 + nameLen + if domainLenIdx >= len(data) { + return fmt.Errorf("%w: truncated before domain_len", + ErrInvalidBip353Name) + } + + name := data[1:domainLenIdx] + + domainStart := domainLenIdx + 1 + domainLen := int(data[domainLenIdx]) + if domainStart+domainLen != len(data) { + return fmt.Errorf("%w: domain length mismatch", + ErrInvalidBip353Name) + } + domain := data[domainStart:] + + if err := checkBip353Alphabet(name); err != nil { + return fmt.Errorf("%w: name: %w", + ErrInvalidBip353Name, err) + } + if err := checkBip353Alphabet(domain); err != nil { + return fmt.Errorf("%w: domain: %w", + ErrInvalidBip353Name, err) + } + + return nil +} + +// checkBip353Alphabet returns an error when any byte falls outside the BIP 353 +// alphabet. +func checkBip353Alphabet(b []byte) error { + for i, c := range b { + switch { + case c >= '0' && c <= '9': + case c >= 'a' && c <= 'z': + case c >= 'A' && c <= 'Z': + case c == '-' || c == '_' || c == '.': + default: + return fmt.Errorf("byte %d (0x%02x) outside "+ + "alphabet", i, c) + } + } + + return nil +} diff --git a/bolt12/validate_test.go b/bolt12/validate_test.go index b1f6056839a..8ed86a0398d 100644 --- a/bolt12/validate_test.go +++ b/bolt12/validate_test.go @@ -400,3 +400,374 @@ func addAmountAndDescription(o *Offer) { ), ) } + +// validInvoiceRequest is the spec-minimal happy-path invoice request that +// each table row mutates to isolate the rule under test. +func validInvoiceRequest(t *testing.T) *InvoiceRequest { + t.Helper() + + ir := &InvoiceRequest{} + + privKey, err := btcec.NewPrivateKey() + require.NoError(t, err) + + var payerID [33]byte + copy(payerID[:], privKey.PubKey().SerializeCompressed()) + ir.InvreqPayerID = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType88, [33]byte](payerID), + ) + + ir.InvreqMetadata = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType0]( + []byte("metadata"), + ), + ) + + ir.InvreqAmount = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType82, TUint64](1000), + ) + + ir.Signature = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType240]([64]byte{0x01}), + ) + + return ir +} + +// TestValidateInvoiceRequestRead pins the BOLT 12 reader-side MUSTs so a +// malformed or unsafe invoice request is rejected. +func TestValidateInvoiceRequestRead(t *testing.T) { + t.Parallel() + + now := time.Unix(1_700_000_000, 0) + activeChain := bitcoinMainnetGenesisHash + + tests := []struct { + name string + mutate func(*InvoiceRequest) + wantErr error + }{ + { + name: "happy path", + mutate: func(*InvoiceRequest) {}, + }, + { + name: "out-of-range TLV in decoded extras", + mutate: func(ir *InvoiceRequest) { + ir.decodedTLVs = tlv.TypeMap{200: nil} + }, + wantErr: ErrOutOfRangeType, + }, + { + name: "unknown even TLV type in range rejected", + mutate: func(ir *InvoiceRequest) { + ir.decodedTLVs = tlv.TypeMap{158: nil} + }, + wantErr: ErrUnknownEvenType, + }, + { + name: "unknown even feature bit rejected", + mutate: func(ir *InvoiceRequest) { + ir.InvreqFeatures = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType84]( + *lnwire.NewRawFeatureVector(0), + ), + ) + }, + wantErr: ErrUnknownEvenFeature, + }, + { + name: "missing payer_id", + mutate: func(ir *InvoiceRequest) { + ir.InvreqPayerID = tlv.OptionalRecordT[ + tlv.TlvType88, [33]byte]{} + }, + wantErr: ErrMissingPayerID, + }, + { + name: "missing metadata", + mutate: func(ir *InvoiceRequest) { + ir.InvreqMetadata = tlv.OptionalRecordT[ + tlv.TlvType0, tlv.Blob]{} + }, + wantErr: ErrMissingMetadata, + }, + { + name: "missing amount (neither offer nor invreq)", + mutate: func(ir *InvoiceRequest) { + ir.InvreqAmount = tlv.OptionalRecordT[ + tlv.TlvType82, TUint64]{} + }, + wantErr: ErrMissingAmount, + }, + { + name: "unsupported chain in offer_chains", + mutate: func(ir *InvoiceRequest) { + var c [32]byte + c[0] = 0xaa + ir.OfferChains = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType2]( + ChainsRecord{ + Chains: [][32]byte{c}, + }, + ), + ) + }, + wantErr: ErrUnsupportedChain, + }, + { + name: "mismatched invreq_chain", + mutate: func(ir *InvoiceRequest) { + var c [32]byte + c[0] = 0xaa + ir.InvreqChain = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType80]( + c, + ), + ) + }, + wantErr: ErrUnsupportedChain, + }, + { + name: "expired via mirrored offer expiry", + mutate: func(ir *InvoiceRequest) { + expiry := uint64(now.Unix()) - 1 + ir.OfferAbsoluteExpiry = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType14]( + TUint64(expiry), + ), + ) + }, + wantErr: ErrOfferExpired, + }, + { + name: "invalid UTF-8 in payer_note", + mutate: func(ir *InvoiceRequest) { + ir.InvreqPayerNote = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType89]( + []byte{0xff}, + ), + ) + }, + wantErr: ErrInvalidUTF8, + }, + { + name: "missing signature", + mutate: func(ir *InvoiceRequest) { + ir.Signature = tlv.OptionalRecordT[ + tlv.TlvType240, [64]byte]{} + }, + wantErr: ErrMissingSignature, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ir := validInvoiceRequest(t) + tc.mutate(ir) + + err := ValidateInvoiceRequestRead(now, activeChain, ir) + if tc.wantErr == nil { + require.NoError(t, err) + return + } + require.ErrorIs(t, err, tc.wantErr) + }) + } +} + +// TestValidateInvoiceRequestWrite pins the BOLT 12 writer-side MUSTs so a +// malformed or incomplete invoice request is rejected. +func TestValidateInvoiceRequestWrite(t *testing.T) { + t.Parallel() + + privKey, err := btcec.NewPrivateKey() + require.NoError(t, err) + + var payerID [33]byte + copy(payerID[:], privKey.PubKey().SerializeCompressed()) + + tests := []struct { + name string + mutate func(*InvoiceRequest) + wantErr error + }{ + { + name: "missing payer_id", + mutate: func(ir *InvoiceRequest) { + ir.InvreqPayerID = tlv.OptionalRecordT[ + tlv.TlvType88, [33]byte]{} + }, + wantErr: ErrMissingPayerID, + }, + { + name: "missing metadata", + mutate: func(ir *InvoiceRequest) { + ir.InvreqMetadata = tlv.OptionalRecordT[ + tlv.TlvType0, tlv.Blob]{} + }, + wantErr: ErrMissingMetadata, + }, + { + name: "missing amount", + mutate: func(ir *InvoiceRequest) { + ir.InvreqAmount = tlv.OptionalRecordT[ + tlv.TlvType82, TUint64]{} + }, + wantErr: ErrMissingAmount, + }, + { + name: "invalid UTF-8 in payer_note", + mutate: func(ir *InvoiceRequest) { + ir.InvreqPayerNote = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType89]( + []byte{0xff}, + ), + ) + }, + wantErr: ErrInvalidUTF8, + }, + { + name: "empty blinded paths", + mutate: func(ir *InvoiceRequest) { + ir.InvreqPaths = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType90]( + lnwire.BlindedPaths{Paths: nil}, + ), + ) + }, + wantErr: ErrEmptyBlindedPaths, + }, + { + name: "happy path", + mutate: func(*InvoiceRequest) {}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Start with a valid request and mutate it. + ir := &InvoiceRequest{ + InvreqPayerID: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType88]( + payerID, + ), + ), + InvreqMetadata: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType0]( + []byte("metadata"), + ), + ), + InvreqAmount: tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType82, TUint64]( + 1000, + ), + ), + } + + tc.mutate(ir) + + err := ValidateInvoiceRequestWrite(ir) + if tc.wantErr == nil { + require.NoError(t, err) + return + } + require.ErrorIs(t, err, tc.wantErr) + }) + } +} + +// bip353Blob assembles a name+domain pair into the wire layout expected +// by invreq_bip_353_name (TLV 91). +func bip353Blob(name, domain []byte) []byte { + out := make([]byte, 0, 2+len(name)+len(domain)) + out = append(out, byte(len(name))) + out = append(out, name...) + out = append(out, byte(len(domain))) + out = append(out, domain...) + + return out +} + +// TestCheckBip353Name exercises the BIP 353 alphabet and structural +// requirements directly so each rejection path is pinned independently +// of the surrounding invoice-request validators. +func TestCheckBip353Name(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + blob []byte + wantErr bool + }{ + { + name: "happy path with allowed alphabet", + blob: bip353Blob( + []byte("alice.example-1_2"), + []byte("example.com"), + ), + wantErr: false, + }, + { + name: "absent field is no-op", + blob: nil, + wantErr: false, + }, + { + name: "name byte outside alphabet", + blob: bip353Blob( + []byte("alice@bob"), []byte("ex.com"), + ), + wantErr: true, + }, + { + name: "domain byte outside alphabet", + blob: bip353Blob([]byte("alice"), []byte("ex com")), + wantErr: true, + }, + { + name: "name truncated before domain_len", + blob: []byte{0x05, 'a', 'l', 'i'}, + wantErr: true, + }, + { + name: "domain length mismatch", + blob: []byte{0x01, 'a', 0x05, 'b'}, + wantErr: true, + }, + { + name: "control byte rejected in name", + blob: bip353Blob( + []byte{'a', 0x00, 'b'}, []byte("ex"), + ), + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var opt tlv.OptionalRecordT[tlv.TlvType91, tlv.Blob] + if tc.blob != nil { + opt = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType91]( + tc.blob, + ), + ) + } + + err := checkBip353Name(opt) + if tc.wantErr { + require.Error(t, err) + require.ErrorIs(t, err, ErrInvalidBip353Name) + } else { + require.NoError(t, err) + } + }) + } +}