From 998989a4f9b56310d52f89d067f187b429535a06 Mon Sep 17 00:00:00 2001 From: Alexey Kuleshevich Date: Fri, 27 Dec 2024 16:11:33 -0700 Subject: [PATCH] Add an instance for the new `SeedGen` type class --- System/Random/MWC.hs | 37 +++++++++++++++++++++++++++++++++++-- mwc-random.cabal | 4 +++- tests/props.hs | 2 +- 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/System/Random/MWC.hs b/System/Random/MWC.hs index 134bd02..b35b769 100644 --- a/System/Random/MWC.hs +++ b/System/Random/MWC.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE BangPatterns, CPP, DeriveDataTypeable, FlexibleContexts, +{-# LANGUAGE BangPatterns, CPP, DataKinds, DeriveDataTypeable, FlexibleContexts, FlexibleInstances, MultiParamTypeClasses, MagicHash, Rank2Types, ScopedTypeVariables, TypeFamilies, UnboxedTuples, TypeOperators #-} @@ -164,12 +164,14 @@ import Control.Monad.ST (ST,runST) import Data.Bits ((.&.), (.|.), shiftL, shiftR, xor) import Data.Int (Int8, Int16, Int32, Int64) import Data.IORef (IORef, atomicModifyIORef, newIORef) +import Data.Maybe (fromMaybe) import Data.Typeable (Typeable) import Data.Vector.Generic (Vector) import Data.Word import Data.Kind import qualified Data.Vector.Generic as G import qualified Data.Vector.Generic.Mutable as GM +import qualified Data.Vector.Primitive as P import qualified Data.Vector.Unboxed as I import qualified Data.Vector.Unboxed.Mutable as M import System.IO (hPutStrLn, stderr) @@ -177,6 +179,11 @@ import System.IO.Unsafe (unsafePerformIO) import qualified Control.Exception as E import System.Random.MWC.SeedSource import qualified System.Random.Stateful as Random +#if MIN_VERSION_random(1,3,0) +import qualified Data.Primitive.ByteArray as Primitive +import qualified Data.Array.Byte as Data +#endif + -- | NOTE: Consider use of more principled type classes -- 'Random.Uniform' and 'Random.UniformRange' instead. @@ -486,6 +493,32 @@ instance PrimMonad m => Random.ThawedGen Seed m where #endif thawGen = restore +#if MIN_VERSION_random(1,3,0) +instance Random.SeedGen Seed where + type SeedSize Seed = 1032 -- == 4 * 258 + fromSeed = toSeed . P.Vector 0 258 . compatFromPrimByteArray . Random.unSeed + toSeed seed = + seedFromVector $ (P.convert :: I.Vector Word32 -> P.Vector Word32) $ fromSeed seed + where + seedFromVector v = + case v of + P.Vector 0 258 ba -> + fromMaybe (error "ByteArray had an unexpected length") $ Random.mkSeed $ compatToPrimByteArray ba + _ | P.length v == 258 -> seedFromVector $ P.force v + _ -> error $ "Impossible: Seed had an unexpected length of: " ++ show (P.length v) + +compatToPrimByteArray :: Data.ByteArray -> Primitive.ByteArray +compatFromPrimByteArray :: Primitive.ByteArray -> Data.ByteArray +#if MIN_VERSION_primitive(0,8,0) +compatToPrimByteArray = id +compatFromPrimByteArray = id +#else +compatToPrimByteArray (Data.ByteArray ba) = Primitive.ByteArray ba +compatFromPrimByteArray (Primitive.ByteArray ba) = Data.ByteArray ba +#endif +#endif + + -- | Convert vector to 'Seed'. It acts similarly to 'initialize' and -- will accept any vector. If you want to pass seed immediately to -- restore you better call initialize directly since following law holds: @@ -582,7 +615,7 @@ nextIndex i = fromIntegral j -- The multiplicator : 0x5BCF5AB2 -- --- Eventhough it is a 'Word64', it is important for the correctness of the proof +-- Even though it is a 'Word64', it is important for the correctness of the proof -- on carry value that it is /not/ greater than maxBound 'Word32'. aa :: Word64 aa = 1540315826 diff --git a/mwc-random.cabal b/mwc-random.cabal index f647fff..9d2c97d 100644 --- a/mwc-random.cabal +++ b/mwc-random.cabal @@ -71,8 +71,10 @@ library , primitive >= 0.6.2 , random >= 1.2 , time - , vector >= 0.7 + , vector >= 0.10.12 , math-functions >= 0.2.1.0 + if impl(ghc < 9.4) + build-depends: data-array-byte ghc-options: -Wall -funbox-strict-fields -fwarn-tabs diff --git a/tests/props.hs b/tests/props.hs index 00ed03d..1871e9a 100644 --- a/tests/props.hs +++ b/tests/props.hs @@ -245,7 +245,7 @@ logProbBinomial n p k k' = fromIntegral k nk' = fromIntegral $ n - k - + cumulativeChi2 :: Int -> Double -> Double cumulativeChi2 (fromIntegral -> ndf) x | x <= 0 = 0