From 9e9a6a7c34fdb446ea6c8783493b096afc0384a9 Mon Sep 17 00:00:00 2001 From: Ben Luddy Date: Mon, 18 May 2026 14:54:02 -0400 Subject: [PATCH] Add option to allow decoding CBOR floats into the Go int types. The default behavior remains unchanged (CBOR floats can only be decoded into Go floats). By configuring the option, integral-valued CBOR floats can be decoded into the Go integer and unsigned integer types as long as the destination type is capable of faithfully representing their value. Signed-off-by: Ben Luddy --- decode.go | 84 ++++++++- decode_test.go | 450 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 530 insertions(+), 4 deletions(-) diff --git a/decode.go b/decode.go index 6b9d9e63..95806f40 100644 --- a/decode.go +++ b/decode.go @@ -769,6 +769,25 @@ func (tum TextUnmarshalerMode) valid() bool { return tum >= 0 && tum < maxTextUnmarshalerMode } +// FloatToIntMode specifies whether CBOR floating-point values can be decoded into Go integer types. +type FloatToIntMode int + +const ( + // FloatToIntForbidden disallows decoding CBOR floats into Go integer types. + FloatToIntForbidden FloatToIntMode = iota + + // FloatToIntAllowExact permits decoding CBOR float values into Go integer + // types if the float value can be represented exactly in the destination + // type without loss of precision. NaN and infinity are never permitted. + FloatToIntAllowExact + + maxFloatToIntMode +) + +func (ftim FloatToIntMode) valid() bool { + return ftim >= 0 && ftim < maxFloatToIntMode +} + // DecOptions specifies decoding options. type DecOptions struct { // DupMapKey specifies whether to enforce duplicate map key. @@ -912,6 +931,10 @@ type DecOptions struct { // implement json.Unmarshaler but do not also implement cbor.Unmarshaler. If nil, decoding // behavior is not influenced by whether or not a type implements json.Unmarshaler. JSONUnmarshalerTranscoder Transcoder + + // FloatToInt specifies whether CBOR floating-point values can be decoded into Go integer + // types. By default, decoding a CBOR float into a Go integer type produces an error. + FloatToInt FloatToIntMode } // DecMode returns DecMode with immutable options and no tags (safe for concurrency). @@ -1128,6 +1151,10 @@ func (opts DecOptions) decMode() (*decMode, error) { //nolint:gocritic // ignore return nil, errors.New("cbor: invalid TextUnmarshaler " + strconv.Itoa(int(opts.TextUnmarshaler))) } + if !opts.FloatToInt.valid() { + return nil, errors.New("cbor: invalid FloatToInt " + strconv.Itoa(int(opts.FloatToInt))) + } + dm := decMode{ dupMapKey: opts.DupMapKey, timeTag: opts.TimeTag, @@ -1157,6 +1184,7 @@ func (opts DecOptions) decMode() (*decMode, error) { //nolint:gocritic // ignore binaryUnmarshaler: opts.BinaryUnmarshaler, textUnmarshaler: opts.TextUnmarshaler, jsonUnmarshalerTranscoder: opts.JSONUnmarshalerTranscoder, + floatToInt: opts.FloatToInt, } return &dm, nil @@ -1238,6 +1266,7 @@ type decMode struct { binaryUnmarshaler BinaryUnmarshalerMode textUnmarshaler TextUnmarshalerMode jsonUnmarshalerTranscoder Transcoder + floatToInt FloatToIntMode } var defaultDecMode, _ = DecOptions{}.decMode() @@ -1280,6 +1309,7 @@ func (dm *decMode) DecOptions() DecOptions { BinaryUnmarshaler: dm.binaryUnmarshaler, TextUnmarshaler: dm.textUnmarshaler, JSONUnmarshalerTranscoder: dm.jsonUnmarshalerTranscoder, + FloatToInt: dm.floatToInt, } } @@ -1584,15 +1614,15 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin switch ai { case additionalInformationAsFloat16: f := float64(float16.Frombits(uint16(val)).Float32()) //nolint:gosec - return fillFloat(t, f, v) + return fillFloat(t, f, v, d.dm.floatToInt) case additionalInformationAsFloat32: f := float64(math.Float32frombits(uint32(val))) //nolint:gosec - return fillFloat(t, f, v) + return fillFloat(t, f, v, d.dm.floatToInt) case additionalInformationAsFloat64: f := math.Float64frombits(val) - return fillFloat(t, f, v) + return fillFloat(t, f, v, d.dm.floatToInt) default: // ai <= 24 if d.dm.simpleValues.rejected[SimpleValue(val)] { //nolint:gosec @@ -3144,7 +3174,7 @@ func fillBool(t cborType, val bool, v reflect.Value) error { return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()} } -func fillFloat(t cborType, val float64, v reflect.Value) error { +func fillFloat(t cborType, val float64, v reflect.Value, fti FloatToIntMode) error { switch v.Kind() { case reflect.Float32, reflect.Float64: if v.OverflowFloat(val) { @@ -3157,6 +3187,52 @@ func fillFloat(t cborType, val float64, v reflect.Value) error { v.SetFloat(val) return nil } + + if fti != FloatToIntAllowExact { + return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()} + } + + // Modf returns (NaN, NaN) for NaN and (+/-Inf, NaN) for +/-Inf, so + // frac != 0 is true in all cases. + i, frac := math.Modf(val) + if frac != 0 { + return &UnmarshalTypeError{ + CBORType: t.String(), + GoType: v.Type().String(), + errorMsg: strconv.FormatFloat(val, 'f', -1, 64) + " is not an integral value", + } + } + + // Range-check before converting to int64/uint64, because the Go spec + // makes float-to-integer conversion implementation-dependent when the + // value is out of range. MinInt64 (-2^63), 2^63, and 2^64 are all + // exact as float64 because they are powers of two. + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n := int64(i) + if i < math.MinInt64 || i >= 1<<63 || v.OverflowInt(n) { + return &UnmarshalTypeError{ + CBORType: t.String(), + GoType: v.Type().String(), + errorMsg: strconv.FormatFloat(val, 'f', -1, 64) + " overflows " + v.Type().String(), + } + } + v.SetInt(n) + return nil + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + n := uint64(i) + if i < 0 || i >= 1<<64 || v.OverflowUint(n) { + return &UnmarshalTypeError{ + CBORType: t.String(), + GoType: v.Type().String(), + errorMsg: strconv.FormatFloat(val, 'f', -1, 64) + " overflows " + v.Type().String(), + } + } + v.SetUint(n) + return nil + } + return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()} } diff --git a/decode_test.go b/decode_test.go index 19097efa..457d1c15 100644 --- a/decode_test.go +++ b/decode_test.go @@ -5558,6 +5558,7 @@ func TestDecOptions(t *testing.T) { BinaryUnmarshaler: BinaryUnmarshalerNone, TextUnmarshaler: TextUnmarshalerTextString, JSONUnmarshalerTranscoder: stubTranscoder{}, + FloatToInt: FloatToIntAllowExact, } ov := reflect.ValueOf(opts1) for i := range ov.NumField() { @@ -11213,3 +11214,452 @@ func TestByteStringExpectedFormatErrorDefaultCase(t *testing.T) { t.Errorf("Error() = %q, want %q", got, want) } } + +func TestDecModeInvalidFloatToInt(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + wantErrorMsg string + }{ + { + name: "below range of valid modes", + opts: DecOptions{FloatToInt: -1}, + wantErrorMsg: "cbor: invalid FloatToInt -1", + }, + { + name: "above range of valid modes", + opts: DecOptions{FloatToInt: 101}, + wantErrorMsg: "cbor: invalid FloatToInt 101", + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.opts.DecMode() + if err == nil { + t.Errorf("DecMode() didn't return an error") + } else if gotErrorMsg := err.Error(); gotErrorMsg != tc.wantErrorMsg { + t.Errorf("DecMode() returned error %q, want %q", gotErrorMsg, tc.wantErrorMsg) + } + }) + } +} + +func TestFloatToIntDecMode(t *testing.T) { + for _, tc := range []struct { + name string + opt FloatToIntMode + src []byte + dst any + want any + wantErrorMsg string + }{ + // FloatToIntForbidden: float into integer produces an error. + { + name: "FloatToIntForbidden float32 into int", + opt: FloatToIntForbidden, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(int), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type int", + }, + { + name: "FloatToIntForbidden float32 into uint", + opt: FloatToIntForbidden, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(uint), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type uint", + }, + + // FloatToIntForbidden: float into float is still fine. + { + name: "FloatToIntForbidden float32 into float32", + opt: FloatToIntForbidden, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(float32), + want: float32(42.0), + }, + { + name: "FloatToIntForbidden float32 into float64", + opt: FloatToIntForbidden, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(float64), + want: float64(42.0), + }, + + // FloatToIntAllowExact: successful conversions for float32. + { + name: "float32 into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(int), + want: int(42), + }, + { + name: "float32 into int8", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(int8), + want: int8(42), + }, + { + name: "float32 into int16", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(int16), + want: int16(42), + }, + { + name: "float32 into int32", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(int32), + want: int32(42), + }, + { + name: "float32 into int64", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(int64), + want: int64(42), + }, + { + name: "float32 into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(uint), + want: uint(42), + }, + { + name: "float32 into uint8", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(uint8), + want: uint8(42), + }, + { + name: "float32 into uint16", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(uint16), + want: uint16(42), + }, + { + name: "float32 into uint32", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(uint32), + want: uint32(42), + }, + { + name: "float32 into uint64", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(uint64), + want: uint64(42), + }, + + // FloatToIntAllowExact: successful conversions for float64. + { + name: "float64 into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("fb3ff0000000000000"), // 1.0 + dst: new(int), + want: int(1), + }, + { + name: "float64 into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("fb3ff0000000000000"), // 1.0 + dst: new(uint), + want: uint(1), + }, + + // FloatToIntAllowExact: successful conversion for float16. + { + name: "float16 NaN into float32 unaffected", + opt: FloatToIntAllowExact, + src: mustHexDecode("f97e00"), // NaN encoded as float16 + dst: new(float32), + }, + + // FloatToIntAllowExact: negative integral float into signed type. + { + name: "negative float32 into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("fabf800000"), // -1.0 + dst: new(int), + want: int(-1), + }, + { + name: "negative float32 into int8", + opt: FloatToIntAllowExact, + src: mustHexDecode("fac3000000"), // -128.0 + dst: new(int8), + want: int8(-128), + }, + + // FloatToIntAllowExact: zero and negative zero. + { + name: "float32 0.0 into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa00000000"), // 0.0 + dst: new(int), + want: int(0), + }, + { + name: "float32 0.0 into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa00000000"), // 0.0 + dst: new(uint), + want: uint(0), + }, + { + name: "float64 -0.0 into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("fb8000000000000000"), // -0.0 + dst: new(int), + want: int(0), + }, + { + name: "float64 -0.0 into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("fb8000000000000000"), // -0.0 + dst: new(uint), + want: uint(0), + }, + + // FloatToIntAllowExact: fractional values rejected. + { + name: "fractional float32 into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa4048f5c3"), // 3.14 + dst: new(int), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type int (3.140000104904175 is not an integral value)", + }, + { + name: "fractional float64 into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("fb40091eb851eb851f"), // 3.14 + dst: new(int), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type int (3.14 is not an integral value)", + }, + { + name: "fractional float32 into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa4048f5c3"), // 3.14 + dst: new(uint), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type uint (3.140000104904175 is not an integral value)", + }, + + // FloatToIntAllowExact: NaN rejected. + { + name: "NaN into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("f97e00"), // NaN + dst: new(int), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type int (NaN is not an integral value)", + }, + { + name: "NaN into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("f97e00"), // NaN + dst: new(uint), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type uint (NaN is not an integral value)", + }, + + // FloatToIntAllowExact: Infinity rejected. + { + name: "+Inf into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("f97c00"), // +Inf + dst: new(int), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type int (+Inf is not an integral value)", + }, + { + name: "-Inf into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("f9fc00"), // -Inf + dst: new(int), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type int (-Inf is not an integral value)", + }, + { + name: "+Inf into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("f97c00"), // +Inf + dst: new(uint), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type uint (+Inf is not an integral value)", + }, + { + name: "-Inf into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("f9fc00"), // -Inf + dst: new(uint), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type uint (-Inf is not an integral value)", + }, + + // FloatToIntAllowExact: signed overflow. + { + name: "float32 128.0 overflows int8", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa43000000"), // 128.0 + dst: new(int8), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type int8 (128 overflows int8)", + }, + { + name: "float32 -129.0 overflows int8", + opt: FloatToIntAllowExact, + src: mustHexDecode("fac3010000"), // -129.0 + dst: new(int8), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type int8 (-129 overflows int8)", + }, + { + name: "float64 overflows int64", + opt: FloatToIntAllowExact, + src: mustHexDecode("fb43e0000000000000"), // math.MaxInt64+1 + dst: new(int64), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type int64 (9223372036854776000 overflows int64)", + }, + + // FloatToIntAllowExact: unsigned overflow. + { + name: "float32 256.0 overflows uint8", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa43800000"), // 256.0 + dst: new(uint8), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type uint8 (256 overflows uint8)", + }, + { + name: "float64 overflows uint64", + opt: FloatToIntAllowExact, + src: mustHexDecode("fb43f0000000000000"), // math.MaxUint64+1 + dst: new(uint64), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type uint64 (18446744073709552000 overflows uint64)", + }, + + // FloatToIntAllowExact: negative into unsigned. + { + name: "negative float32 into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("fabf800000"), // -1.0 + dst: new(uint), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type uint (-1 overflows uint)", + }, + { + name: "negative float32 into uint8", + opt: FloatToIntAllowExact, + src: mustHexDecode("fabf800000"), // -1.0 + dst: new(uint8), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type uint8 (-1 overflows uint8)", + }, + + // FloatToIntAllowExact: float destinations unaffected. + { + name: "float32 into float32 still works", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa4048f5c3"), // 3.14 + dst: new(float32), + want: float32(3.14), + }, + { + name: "float64 into float64 still works", + opt: FloatToIntAllowExact, + src: mustHexDecode("fb40091eb851eb851f"), // 3.14 + dst: new(float64), + want: float64(3.14), + }, + { + name: "float64 into any still works", + opt: FloatToIntAllowExact, + src: mustHexDecode("fb40091eb851eb851f"), // 3.14 + dst: new(any), + want: float64(3.14), + }, + + // FloatToIntAllowExact: float16 integral value into integer. + { + name: "float16 into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("f93c00"), // float16 1.0 + dst: new(int), + want: int(1), + }, + { + name: "float16 into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("f93c00"), // float16 1.0 + dst: new(uint), + want: uint(1), + }, + + // FloatToIntAllowExact: struct field. + { + name: "float32 into struct int field", + opt: FloatToIntAllowExact, + src: mustHexDecode("a16146fa42280000"), // {"F": 42.0} + dst: &struct{ F int }{}, + want: struct{ F int }{F: 42}, + }, + + // FloatToIntAllowExact: map with float key and value into map[int]int. + { + name: "float32 key and value into map[int]int", + opt: FloatToIntAllowExact, + src: mustHexDecode("a1fa3f800000fa40000000"), // {1.0: 2.0} + dst: &map[int]int{}, + want: map[int]int{1: 2}, + }, + + // FloatToIntAllowExact: array of float into []int and [1]int. + { + name: "float32 array into []int", + opt: FloatToIntAllowExact, + src: mustHexDecode("81fa42280000"), // [42.0] + dst: &[]int{}, + want: []int{42}, + }, + { + name: "float32 array into [1]int", + opt: FloatToIntAllowExact, + src: mustHexDecode("81fa42280000"), // [42.0] + dst: &[1]int{}, + want: [1]int{42}, + }, + + // FloatToIntAllowExact: non-numeric destination still errors. + { + name: "float32 into string", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(string), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type string", + }, + { + name: "float32 into bool", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(bool), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type bool", + }, + } { + t.Run(tc.name, func(t *testing.T) { + dm, err := DecOptions{FloatToInt: tc.opt}.DecMode() + if err != nil { + t.Fatal(err) + } + + err = dm.Unmarshal(tc.src, tc.dst) + if err != nil { + if tc.wantErrorMsg == "" { + t.Errorf("Unmarshal(0x%x) returned unexpected error %v", tc.src, err) + } else if gotErrorMsg := err.Error(); gotErrorMsg != tc.wantErrorMsg { + t.Errorf("Unmarshal(0x%x) returned error %q, want %q", tc.src, gotErrorMsg, tc.wantErrorMsg) + } + } else if tc.wantErrorMsg != "" { + t.Errorf("Unmarshal(0x%x) didn't return an error, want %q", tc.src, tc.wantErrorMsg) + } else if tc.want != nil { + got := reflect.ValueOf(tc.dst).Elem().Interface() + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("Unmarshal(0x%x) = %v (%T), want %v (%T)", tc.src, got, got, tc.want, tc.want) + } + } + }) + } +}