diff --git a/decode.go b/decode.go index 6b9d9e63..95806f40 100644 --- a/decode.go +++ b/decode.go @@ -769,6 +769,25 @@ func (tum TextUnmarshalerMode) valid() bool { return tum >= 0 && tum < maxTextUnmarshalerMode } +// FloatToIntMode specifies whether CBOR floating-point values can be decoded into Go integer types. +type FloatToIntMode int + +const ( + // FloatToIntForbidden disallows decoding CBOR floats into Go integer types. + FloatToIntForbidden FloatToIntMode = iota + + // FloatToIntAllowExact permits decoding CBOR float values into Go integer + // types if the float value can be represented exactly in the destination + // type without loss of precision. NaN and infinity are never permitted. + FloatToIntAllowExact + + maxFloatToIntMode +) + +func (ftim FloatToIntMode) valid() bool { + return ftim >= 0 && ftim < maxFloatToIntMode +} + // DecOptions specifies decoding options. type DecOptions struct { // DupMapKey specifies whether to enforce duplicate map key. @@ -912,6 +931,10 @@ type DecOptions struct { // implement json.Unmarshaler but do not also implement cbor.Unmarshaler. If nil, decoding // behavior is not influenced by whether or not a type implements json.Unmarshaler. JSONUnmarshalerTranscoder Transcoder + + // FloatToInt specifies whether CBOR floating-point values can be decoded into Go integer + // types. By default, decoding a CBOR float into a Go integer type produces an error. + FloatToInt FloatToIntMode } // DecMode returns DecMode with immutable options and no tags (safe for concurrency). @@ -1128,6 +1151,10 @@ func (opts DecOptions) decMode() (*decMode, error) { //nolint:gocritic // ignore return nil, errors.New("cbor: invalid TextUnmarshaler " + strconv.Itoa(int(opts.TextUnmarshaler))) } + if !opts.FloatToInt.valid() { + return nil, errors.New("cbor: invalid FloatToInt " + strconv.Itoa(int(opts.FloatToInt))) + } + dm := decMode{ dupMapKey: opts.DupMapKey, timeTag: opts.TimeTag, @@ -1157,6 +1184,7 @@ func (opts DecOptions) decMode() (*decMode, error) { //nolint:gocritic // ignore binaryUnmarshaler: opts.BinaryUnmarshaler, textUnmarshaler: opts.TextUnmarshaler, jsonUnmarshalerTranscoder: opts.JSONUnmarshalerTranscoder, + floatToInt: opts.FloatToInt, } return &dm, nil @@ -1238,6 +1266,7 @@ type decMode struct { binaryUnmarshaler BinaryUnmarshalerMode textUnmarshaler TextUnmarshalerMode jsonUnmarshalerTranscoder Transcoder + floatToInt FloatToIntMode } var defaultDecMode, _ = DecOptions{}.decMode() @@ -1280,6 +1309,7 @@ func (dm *decMode) DecOptions() DecOptions { BinaryUnmarshaler: dm.binaryUnmarshaler, TextUnmarshaler: dm.textUnmarshaler, JSONUnmarshalerTranscoder: dm.jsonUnmarshalerTranscoder, + FloatToInt: dm.floatToInt, } } @@ -1584,15 +1614,15 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin switch ai { case additionalInformationAsFloat16: f := float64(float16.Frombits(uint16(val)).Float32()) //nolint:gosec - return fillFloat(t, f, v) + return fillFloat(t, f, v, d.dm.floatToInt) case additionalInformationAsFloat32: f := float64(math.Float32frombits(uint32(val))) //nolint:gosec - return fillFloat(t, f, v) + return fillFloat(t, f, v, d.dm.floatToInt) case additionalInformationAsFloat64: f := math.Float64frombits(val) - return fillFloat(t, f, v) + return fillFloat(t, f, v, d.dm.floatToInt) default: // ai <= 24 if d.dm.simpleValues.rejected[SimpleValue(val)] { //nolint:gosec @@ -3144,7 +3174,7 @@ func fillBool(t cborType, val bool, v reflect.Value) error { return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()} } -func fillFloat(t cborType, val float64, v reflect.Value) error { +func fillFloat(t cborType, val float64, v reflect.Value, fti FloatToIntMode) error { switch v.Kind() { case reflect.Float32, reflect.Float64: if v.OverflowFloat(val) { @@ -3157,6 +3187,52 @@ func fillFloat(t cborType, val float64, v reflect.Value) error { v.SetFloat(val) return nil } + + if fti != FloatToIntAllowExact { + return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()} + } + + // Modf returns (NaN, NaN) for NaN and (+/-Inf, NaN) for +/-Inf, so + // frac != 0 is true in all cases. + i, frac := math.Modf(val) + if frac != 0 { + return &UnmarshalTypeError{ + CBORType: t.String(), + GoType: v.Type().String(), + errorMsg: strconv.FormatFloat(val, 'f', -1, 64) + " is not an integral value", + } + } + + // Range-check before converting to int64/uint64, because the Go spec + // makes float-to-integer conversion implementation-dependent when the + // value is out of range. MinInt64 (-2^63), 2^63, and 2^64 are all + // exact as float64 because they are powers of two. + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n := int64(i) + if i < math.MinInt64 || i >= 1<<63 || v.OverflowInt(n) { + return &UnmarshalTypeError{ + CBORType: t.String(), + GoType: v.Type().String(), + errorMsg: strconv.FormatFloat(val, 'f', -1, 64) + " overflows " + v.Type().String(), + } + } + v.SetInt(n) + return nil + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + n := uint64(i) + if i < 0 || i >= 1<<64 || v.OverflowUint(n) { + return &UnmarshalTypeError{ + CBORType: t.String(), + GoType: v.Type().String(), + errorMsg: strconv.FormatFloat(val, 'f', -1, 64) + " overflows " + v.Type().String(), + } + } + v.SetUint(n) + return nil + } + return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()} } diff --git a/decode_test.go b/decode_test.go index 19097efa..457d1c15 100644 --- a/decode_test.go +++ b/decode_test.go @@ -5558,6 +5558,7 @@ func TestDecOptions(t *testing.T) { BinaryUnmarshaler: BinaryUnmarshalerNone, TextUnmarshaler: TextUnmarshalerTextString, JSONUnmarshalerTranscoder: stubTranscoder{}, + FloatToInt: FloatToIntAllowExact, } ov := reflect.ValueOf(opts1) for i := range ov.NumField() { @@ -11213,3 +11214,452 @@ func TestByteStringExpectedFormatErrorDefaultCase(t *testing.T) { t.Errorf("Error() = %q, want %q", got, want) } } + +func TestDecModeInvalidFloatToInt(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + wantErrorMsg string + }{ + { + name: "below range of valid modes", + opts: DecOptions{FloatToInt: -1}, + wantErrorMsg: "cbor: invalid FloatToInt -1", + }, + { + name: "above range of valid modes", + opts: DecOptions{FloatToInt: 101}, + wantErrorMsg: "cbor: invalid FloatToInt 101", + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.opts.DecMode() + if err == nil { + t.Errorf("DecMode() didn't return an error") + } else if gotErrorMsg := err.Error(); gotErrorMsg != tc.wantErrorMsg { + t.Errorf("DecMode() returned error %q, want %q", gotErrorMsg, tc.wantErrorMsg) + } + }) + } +} + +func TestFloatToIntDecMode(t *testing.T) { + for _, tc := range []struct { + name string + opt FloatToIntMode + src []byte + dst any + want any + wantErrorMsg string + }{ + // FloatToIntForbidden: float into integer produces an error. + { + name: "FloatToIntForbidden float32 into int", + opt: FloatToIntForbidden, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(int), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type int", + }, + { + name: "FloatToIntForbidden float32 into uint", + opt: FloatToIntForbidden, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(uint), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type uint", + }, + + // FloatToIntForbidden: float into float is still fine. + { + name: "FloatToIntForbidden float32 into float32", + opt: FloatToIntForbidden, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(float32), + want: float32(42.0), + }, + { + name: "FloatToIntForbidden float32 into float64", + opt: FloatToIntForbidden, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(float64), + want: float64(42.0), + }, + + // FloatToIntAllowExact: successful conversions for float32. + { + name: "float32 into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(int), + want: int(42), + }, + { + name: "float32 into int8", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(int8), + want: int8(42), + }, + { + name: "float32 into int16", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(int16), + want: int16(42), + }, + { + name: "float32 into int32", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(int32), + want: int32(42), + }, + { + name: "float32 into int64", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(int64), + want: int64(42), + }, + { + name: "float32 into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(uint), + want: uint(42), + }, + { + name: "float32 into uint8", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(uint8), + want: uint8(42), + }, + { + name: "float32 into uint16", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(uint16), + want: uint16(42), + }, + { + name: "float32 into uint32", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(uint32), + want: uint32(42), + }, + { + name: "float32 into uint64", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(uint64), + want: uint64(42), + }, + + // FloatToIntAllowExact: successful conversions for float64. + { + name: "float64 into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("fb3ff0000000000000"), // 1.0 + dst: new(int), + want: int(1), + }, + { + name: "float64 into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("fb3ff0000000000000"), // 1.0 + dst: new(uint), + want: uint(1), + }, + + // FloatToIntAllowExact: successful conversion for float16. + { + name: "float16 NaN into float32 unaffected", + opt: FloatToIntAllowExact, + src: mustHexDecode("f97e00"), // NaN encoded as float16 + dst: new(float32), + }, + + // FloatToIntAllowExact: negative integral float into signed type. + { + name: "negative float32 into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("fabf800000"), // -1.0 + dst: new(int), + want: int(-1), + }, + { + name: "negative float32 into int8", + opt: FloatToIntAllowExact, + src: mustHexDecode("fac3000000"), // -128.0 + dst: new(int8), + want: int8(-128), + }, + + // FloatToIntAllowExact: zero and negative zero. + { + name: "float32 0.0 into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa00000000"), // 0.0 + dst: new(int), + want: int(0), + }, + { + name: "float32 0.0 into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa00000000"), // 0.0 + dst: new(uint), + want: uint(0), + }, + { + name: "float64 -0.0 into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("fb8000000000000000"), // -0.0 + dst: new(int), + want: int(0), + }, + { + name: "float64 -0.0 into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("fb8000000000000000"), // -0.0 + dst: new(uint), + want: uint(0), + }, + + // FloatToIntAllowExact: fractional values rejected. + { + name: "fractional float32 into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa4048f5c3"), // 3.14 + dst: new(int), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type int (3.140000104904175 is not an integral value)", + }, + { + name: "fractional float64 into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("fb40091eb851eb851f"), // 3.14 + dst: new(int), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type int (3.14 is not an integral value)", + }, + { + name: "fractional float32 into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa4048f5c3"), // 3.14 + dst: new(uint), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type uint (3.140000104904175 is not an integral value)", + }, + + // FloatToIntAllowExact: NaN rejected. + { + name: "NaN into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("f97e00"), // NaN + dst: new(int), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type int (NaN is not an integral value)", + }, + { + name: "NaN into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("f97e00"), // NaN + dst: new(uint), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type uint (NaN is not an integral value)", + }, + + // FloatToIntAllowExact: Infinity rejected. + { + name: "+Inf into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("f97c00"), // +Inf + dst: new(int), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type int (+Inf is not an integral value)", + }, + { + name: "-Inf into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("f9fc00"), // -Inf + dst: new(int), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type int (-Inf is not an integral value)", + }, + { + name: "+Inf into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("f97c00"), // +Inf + dst: new(uint), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type uint (+Inf is not an integral value)", + }, + { + name: "-Inf into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("f9fc00"), // -Inf + dst: new(uint), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type uint (-Inf is not an integral value)", + }, + + // FloatToIntAllowExact: signed overflow. + { + name: "float32 128.0 overflows int8", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa43000000"), // 128.0 + dst: new(int8), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type int8 (128 overflows int8)", + }, + { + name: "float32 -129.0 overflows int8", + opt: FloatToIntAllowExact, + src: mustHexDecode("fac3010000"), // -129.0 + dst: new(int8), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type int8 (-129 overflows int8)", + }, + { + name: "float64 overflows int64", + opt: FloatToIntAllowExact, + src: mustHexDecode("fb43e0000000000000"), // math.MaxInt64+1 + dst: new(int64), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type int64 (9223372036854776000 overflows int64)", + }, + + // FloatToIntAllowExact: unsigned overflow. + { + name: "float32 256.0 overflows uint8", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa43800000"), // 256.0 + dst: new(uint8), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type uint8 (256 overflows uint8)", + }, + { + name: "float64 overflows uint64", + opt: FloatToIntAllowExact, + src: mustHexDecode("fb43f0000000000000"), // math.MaxUint64+1 + dst: new(uint64), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type uint64 (18446744073709552000 overflows uint64)", + }, + + // FloatToIntAllowExact: negative into unsigned. + { + name: "negative float32 into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("fabf800000"), // -1.0 + dst: new(uint), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type uint (-1 overflows uint)", + }, + { + name: "negative float32 into uint8", + opt: FloatToIntAllowExact, + src: mustHexDecode("fabf800000"), // -1.0 + dst: new(uint8), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type uint8 (-1 overflows uint8)", + }, + + // FloatToIntAllowExact: float destinations unaffected. + { + name: "float32 into float32 still works", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa4048f5c3"), // 3.14 + dst: new(float32), + want: float32(3.14), + }, + { + name: "float64 into float64 still works", + opt: FloatToIntAllowExact, + src: mustHexDecode("fb40091eb851eb851f"), // 3.14 + dst: new(float64), + want: float64(3.14), + }, + { + name: "float64 into any still works", + opt: FloatToIntAllowExact, + src: mustHexDecode("fb40091eb851eb851f"), // 3.14 + dst: new(any), + want: float64(3.14), + }, + + // FloatToIntAllowExact: float16 integral value into integer. + { + name: "float16 into int", + opt: FloatToIntAllowExact, + src: mustHexDecode("f93c00"), // float16 1.0 + dst: new(int), + want: int(1), + }, + { + name: "float16 into uint", + opt: FloatToIntAllowExact, + src: mustHexDecode("f93c00"), // float16 1.0 + dst: new(uint), + want: uint(1), + }, + + // FloatToIntAllowExact: struct field. + { + name: "float32 into struct int field", + opt: FloatToIntAllowExact, + src: mustHexDecode("a16146fa42280000"), // {"F": 42.0} + dst: &struct{ F int }{}, + want: struct{ F int }{F: 42}, + }, + + // FloatToIntAllowExact: map with float key and value into map[int]int. + { + name: "float32 key and value into map[int]int", + opt: FloatToIntAllowExact, + src: mustHexDecode("a1fa3f800000fa40000000"), // {1.0: 2.0} + dst: &map[int]int{}, + want: map[int]int{1: 2}, + }, + + // FloatToIntAllowExact: array of float into []int and [1]int. + { + name: "float32 array into []int", + opt: FloatToIntAllowExact, + src: mustHexDecode("81fa42280000"), // [42.0] + dst: &[]int{}, + want: []int{42}, + }, + { + name: "float32 array into [1]int", + opt: FloatToIntAllowExact, + src: mustHexDecode("81fa42280000"), // [42.0] + dst: &[1]int{}, + want: [1]int{42}, + }, + + // FloatToIntAllowExact: non-numeric destination still errors. + { + name: "float32 into string", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(string), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type string", + }, + { + name: "float32 into bool", + opt: FloatToIntAllowExact, + src: mustHexDecode("fa42280000"), // 42.0 + dst: new(bool), + wantErrorMsg: "cbor: cannot unmarshal primitives into Go value of type bool", + }, + } { + t.Run(tc.name, func(t *testing.T) { + dm, err := DecOptions{FloatToInt: tc.opt}.DecMode() + if err != nil { + t.Fatal(err) + } + + err = dm.Unmarshal(tc.src, tc.dst) + if err != nil { + if tc.wantErrorMsg == "" { + t.Errorf("Unmarshal(0x%x) returned unexpected error %v", tc.src, err) + } else if gotErrorMsg := err.Error(); gotErrorMsg != tc.wantErrorMsg { + t.Errorf("Unmarshal(0x%x) returned error %q, want %q", tc.src, gotErrorMsg, tc.wantErrorMsg) + } + } else if tc.wantErrorMsg != "" { + t.Errorf("Unmarshal(0x%x) didn't return an error, want %q", tc.src, tc.wantErrorMsg) + } else if tc.want != nil { + got := reflect.ValueOf(tc.dst).Elem().Interface() + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("Unmarshal(0x%x) = %v (%T), want %v (%T)", tc.src, got, got, tc.want, tc.want) + } + } + }) + } +}