diff --git a/PRNG.ml b/PRNG.ml index d66cb40..1ae7622 100644 --- a/PRNG.ml +++ b/PRNG.ml @@ -20,6 +20,7 @@ module type STATE = sig val make_self_init: unit -> t val bool: t -> bool val bit: t -> bool + val uniform: t -> float val float: t -> float -> float val byte: t -> int val bits8: t -> int @@ -49,6 +50,7 @@ module type PURE = sig val make_self_init: unit -> t val bool: t -> bool * t val bit: t -> bool * t + val uniform: t -> float * t val float: float -> t -> float * t val byte: t -> int * t val bits8: t -> int * t @@ -143,14 +145,12 @@ let nativeint = then fun g bound -> Nativeint.of_int32 (int32 g (Nativeint.to_int32 bound)) else fun g bound -> Int64.to_nativeint (int64 g (Int64.of_nativeint bound)) -let float_64 g bound = +let rec uniform g = let b = X.bits64 g in - (Int64.(to_float (shift_right_logical b 11)) *. 0x1.p-53) *. bound + let n = Int64.shift_right_logical b 11 in + if n <> 0L then Int64.to_float n *. 0x1.p-53 else uniform g -let float_32 g bound = - let a = X.bits30 g in - let b = X.bits30 g in - (float a *. 0x1.p-60 +. float b *. 0x1.p-30) *. bound +let float g bound = uniform g *. bound end @@ -218,14 +218,13 @@ let nativeint = (Int64.to_nativeint r, g') end -let float_64 bound g = +let rec uniform g = let (b, g) = X.bits64 g in - ((Int64.(to_float (shift_right_logical b 11)) *. 0x1.p-53) *. bound, g) + let n = Int64.shift_right_logical b 11 in + if n <> 0L then (Int64.to_float n *. 0x1.p-53, g) else uniform g -let float_32 bound g = - let (a, g) = X.bits30 g in - let (b, g) = X.bits30 g in - ((float a *. 0x1.p-60 +. float b *. 0x1.p-30) *. bound, g) +let float bound g = + let (f, g) = uniform g in (f *. bound, g) end @@ -317,8 +316,6 @@ include StateDerived(struct let errorprefix = "PRNG.Splitmix.State." end) -let float = float_64 - let bytes g dst ofs len = if ofs < 0 || len < 0 || ofs > Bytes.length dst - len then invalid_arg "PRNG.State.bytes" @@ -401,8 +398,6 @@ include PureDerived(struct let errorprefix = "PRNG.Splitmix.Pure." end) -let float = float_64 - let split g = let g1 = next g in let g2 = next g1 in @@ -554,8 +549,6 @@ include StateDerived(struct let errorprefix = "PRNG.Chacha.State." end) -let float = if Sys.word_size = 64 then float_64 else float_32 - let bytes g dst ofs len = if ofs < 0 || len < 0 || Bytes.length dst - len > ofs then invalid_arg "PRNG.Chacha.State.bytes"; @@ -676,8 +669,6 @@ include PureDerived(struct let errorprefix = "PRNG.Chacha.Pure." end) -let float = if Sys.word_size = 64 then float_64 else float_32 - let bytes g dst ofs len = if ofs < 0 || len < 0 || Bytes.length dst - len > ofs then invalid_arg "PRNG.Chacha.Pure.bytes"; diff --git a/PRNG.mli b/PRNG.mli index 39bc329..446f678 100644 --- a/PRNG.mli +++ b/PRNG.mli @@ -59,10 +59,19 @@ module type STATE = sig val bit: t -> bool (** Return a Boolean value in [false,true] with 0.5 probability each. *) + val uniform: t -> float + (** Return a floating-point number evenly distributed between 0.0 and 1.0. + 0.0 and 1.0 are never returned. + The result is of the form [n * 2{^-53}], where [n] is a random integer + in [(0, 2{^53})]. *) + val float: t -> float -> float (** [float g x] returns a floating-point number evenly distributed - between 0.0 and [x]. If [x] is negative, negative numbers - between [x] and 0.0 are returned. *) + between 0.0 and [x]. If [x] is negative, negative numbers + between [x] and 0.0 are returned. Implemented as [uniform g *. x]. + Consequently, the values [0.0] and [x] can be returned + (as a result of floating-point rounding), but not if [x] is + [1.0], since [float g 1.0] behaves exactly like [uniform g]. *) val byte: t -> int val bits8: t -> int @@ -90,7 +99,7 @@ module type STATE = sig Note that [int32 Int32.max_int] produces numbers between 0 and [Int32.max_int] excluded. To produce numbers between 0 and [Int32.max_int] included, use - [Int32.logand (bits32 g) Int64.max_int]. *) + [Int32.logand (bits32 g) Int32.max_int]. *) val bits64: t -> int64 (** Return a 64-bit integer evenly distributed between @@ -179,6 +188,7 @@ module type PURE = sig val bool: t -> bool * t val bit: t -> bool * t + val uniform: t -> float * t val float: float -> t -> float * t val byte: t -> int * t