diff --git a/atomic-primops/Data/Atomics.hs b/atomic-primops/Data/Atomics.hs index 85d616d..da222a2 100644 --- a/atomic-primops/Data/Atomics.hs +++ b/atomic-primops/Data/Atomics.hs @@ -13,7 +13,7 @@ module Data.Atomics ( -- * Types for atomic operations - Ticket, peekTicket, -- CASResult(..), + Ticket, peekTicket, peekTicketA, -- CASResult(..), -- * Atomic operations on IORefs readForCAS, casIORef, casIORef2, @@ -51,7 +51,7 @@ import Data.Primitive.ByteArray (MutableByteArray(MutableByteArray)) import Data.Atomics.Internal import Data.IORef -import GHC.IORef hiding (atomicModifyIORef) +import GHC.IORef (IORef (..)) import GHC.STRef import GHC.Exts hiding ((==#)) import qualified GHC.PrimopWrappers as GPW @@ -60,8 +60,6 @@ import GHC.IO (IO(IO)) #ifdef DEBUG_ATOMICS #warning "Activating DEBUG_ATOMICS... NOINLINE's and more" -{-# NOINLINE seal #-} - {-# NOINLINE casIORef #-} {-# NOINLINE casArrayElem2 #-} {-# NOINLINE readArrayElem #-} @@ -122,11 +120,8 @@ casArrayElem2 (MutableArray arr#) (I# i#) old new = IO$ \s1# -> -- | Ordinary processor load instruction (non-atomic, not implying any memory barriers). readArrayElem :: forall a . MutableArray RealWorld a -> Int -> IO (Ticket a) --- readArrayElem = unsafeCoerce# readArray# -readArrayElem (MutableArray arr#) (I# i#) = IO $ \ st -> unsafeCoerce# (fn st) - where - fn :: State# RealWorld -> (# State# RealWorld, a #) - fn = readArray# arr# i# +readArrayElem (MutableArray arr#) (I# i#) = IO $ \ st -> + readArrayElem# arr# i# st -- | Compare and swap on word-sized chunks of a byte-array. For indexing purposes -- the bytearray is treated as an array of words (`Int`s). Note that UNLIKE @@ -147,7 +142,7 @@ casByteArrayInt (MutableByteArray mba#) (I# ix#) (I# old#) (I# new#) = -- case casByteArrayInt# mba# ix# old# new# s1# of -- (# s2#, x#, res #) -> (# s2#, (x# ==# 0#, I# res) #) - let (# s2#, res #) = casIntArray# mba# ix# old# new# s1# in + let !(# s2#, res #) = casIntArray# mba# ix# old# new# s1# in (# s2#, (I# res) #) -- I don't know if a let will mak any difference here... hopefully not. @@ -162,7 +157,7 @@ fetchAddIntArray :: MutableByteArray RealWorld -> Int -- ^ The value to be added -> IO Int -- ^ The value *before* the addition fetchAddIntArray (MutableByteArray mba#) (I# offset#) (I# incr#) = IO $ \ s1# -> - let (# s2#, res #) = fetchAddIntArray# mba# offset# incr# s1# in + let !(# s2#, res #) = fetchAddIntArray# mba# offset# incr# s1# in (# s2#, (I# res) #) @@ -215,7 +210,7 @@ doAtomicRMW :: (MutableByteArray# RealWorld -> Int# -> Int# -> State# RealWorld doAtomicRMW atomicOp# = \(MutableByteArray mba#) (I# offset#) (I# val#) -> IO $ \ s1# -> - let (# s2#, res #) = atomicOp# mba# offset# val# s1# in + let !(# s2#, res #) = atomicOp# mba# offset# val# s1# in (# s2#, (I# res) #) @@ -227,7 +222,7 @@ doAtomicRMW atomicOp# = -- such as in GCC's `__sync_add_and_fetch`. fetchAddByteArrayInt :: MutableByteArray RealWorld -> Int -> Int -> IO Int fetchAddByteArrayInt (MutableByteArray mba#) (I# offset#) (I# incr#) = IO $ \ s1# -> - let (# s2#, res #) = fetchAddIntArray# mba# offset# incr# s1# in + let !(# s2#, res #) = fetchAddIntArray# mba# offset# incr# s1# in (# s2#, (I# (res +# incr#)) #) @@ -301,19 +296,6 @@ casIORef2 (IORef (STRef var)) old new = casMutVar2 var old new -------------------------------------------------------------------------------- --- | A ticket contains or can get the usable Haskell value. --- This function does just that. -{-# NOINLINE peekTicket #-} --- At least this function MUST remain NOINLINE. Issue5 is an example of a bug that --- ensues otherwise. -peekTicket :: Ticket a -> a -peekTicket = unsafeCoerce# - --- Not exposing this for now. Presently the idea is that you must read from the --- mutable data structure itself to get a ticket. -seal :: a -> Ticket a -seal = unsafeCoerce# - -- | Like `readForCAS`, but for `MutVar#`. readMutVarForCAS :: MutVar# RealWorld a -> IO ( Ticket a ) readMutVarForCAS mv = IO$ \ st -> readForCAS# mv st @@ -381,8 +363,8 @@ foreign import ccall unsafe "DUP_write_barrier" writeBarrier -- | A drop-in replacement for `atomicModifyIORef` that -- optimistically attempts to compute the new value and CAS it into -- place without introducing new thunks or locking anything. Note --- that this is more STRICT than its standard counterpart and will only --- place evaluated (WHNF) values in the IORef. +-- that this is more STRICT than its standard counterpart; the value in the +-- 'IORef` will always be forced to WHNF before the function returns. -- -- The upside is that sometimes we see a performance benefit. -- The downside is that this version is speculative -- when it @@ -396,9 +378,10 @@ atomicModifyIORefCAS ref fn = do loop tick effort where effort = 30 :: Int -- TODO: Tune this. - loop _ 0 = atomicModifyIORef ref fn -- Fall back to the regular version. + loop _ 0 = atomicModifyIORef ref (\x -> let r = fn x in fst r `seq` r) -- Fall back to the regular version. loop old tries = do - (new,result) <- evaluate $ fn $ peekTicket old + oldVal <- peekTicketA old + (new,result) <- evaluate $ fn $ oldVal (b,tick) <- casIORef ref old new if b then return result @@ -415,9 +398,10 @@ atomicModifyIORefCAS_ ref fn = do loop tick effort where effort = 30 :: Int -- TODO: Tune this. - loop _ 0 = atomicModifyIORef ref (\ x -> (fn x, ())) + loop _ 0 = atomicModifyIORef ref (\ x -> let r = fn x in r `seq` (fn x, ())) loop old tries = do - new <- evaluate $ fn $ peekTicket old + oldVal <- peekTicketA old + new <- evaluate $ fn $ oldVal (b,val) <- casIORef ref old new if b then return () diff --git a/atomic-primops/Data/Atomics/Counter.hs b/atomic-primops/Data/Atomics/Counter.hs index f9747fa..eb481a8 100644 --- a/atomic-primops/Data/Atomics/Counter.hs +++ b/atomic-primops/Data/Atomics/Counter.hs @@ -111,7 +111,7 @@ peekCTicket !x = x casCounter :: AtomicCounter -> CTicket -> Int -> IO (Bool, CTicket) -- casCounter (AtomicCounter barr) !old !new = casCounter (AtomicCounter mba#) (I# old#) newBox@(I# new#) = IO$ \s1# -> - let (# s2#, res# #) = casIntArray# mba# 0# old# new# s1# in + let !(# s2#, res# #) = casIntArray# mba# 0# old# new# s1# in case res# ==# old# of False -> (# s2#, (False, I# res# ) #) -- Failure True -> (# s2#, (True , newBox ) #) -- Success @@ -130,12 +130,12 @@ casCounter (AtomicCounter mba#) (I# old#) newBox@(I# new#) = IO$ \s1# -> -- loop like CAS. incrCounter :: Int -> AtomicCounter -> IO Int incrCounter (I# incr#) (AtomicCounter mba#) = IO $ \ s1# -> - let (# s2#, res #) = fetchAddIntArray# mba# 0# incr# s1# in + let !(# s2#, res #) = fetchAddIntArray# mba# 0# incr# s1# in (# s2#, (I# (res +# incr#)) #) {-# INLINE incrCounter_ #-} -- | An alternate version for when you don't care about the old value. incrCounter_ :: Int -> AtomicCounter -> IO () incrCounter_ (I# incr#) (AtomicCounter mba#) = IO $ \ s1# -> - let (# s2#, _ #) = fetchAddIntArray# mba# 0# incr# s1# in + let !(# s2#, _ #) = fetchAddIntArray# mba# 0# incr# s1# in (# s2#, () #) diff --git a/atomic-primops/Data/Atomics/Internal.hs b/atomic-primops/Data/Atomics/Internal.hs index 6bd36a2..a1a3e36 100644 --- a/atomic-primops/Data/Atomics/Internal.hs +++ b/atomic-primops/Data/Atomics/Internal.hs @@ -1,5 +1,6 @@ {-# LANGUAGE CPP, TypeSynonymInstances, BangPatterns #-} {-# LANGUAGE ForeignFunctionInterface, GHCForeignImportPrim, MagicHash, UnboxedTuples, UnliftedFFITypes #-} +{-# LANGUAGE ScopedTypeVariables #-} #define CASTFUN @@ -9,24 +10,36 @@ module Data.Atomics.Internal ( casIntArray#, fetchAddIntArray#, readForCAS#, casMutVarTicketed#, casArrayTicketed#, + readArrayElem#, Ticket, - -- * Very unsafe, not to be used - ptrEq + peekTicket, + peekTicketA, + seal, + -- * Very unsafe; for testing only + reallyUnsafeTicketEquality ) where -import GHC.Exts (Int(I#), Any, RealWorld, Int#, State#, MutableArray#, MutVar#, - unsafeCoerce#, reallyUnsafePtrEquality#, - casArray#, casIntArray#, fetchAddIntArray#, readMutVar#, casMutVar#) +import GHC.Exts (Int(I#), RealWorld, Int#, State#, MutableArray#, MutVar#, + reallyUnsafePtrEquality#, readArray#, + casArray#, casIntArray#, fetchAddIntArray#, readMutVar#, casMutVar#, lazy) #ifdef DEBUG_ATOMICS {-# NOINLINE readForCAS# #-} +{-# NOINLINE readArrayElem# #-} {-# NOINLINE casMutVarTicketed# #-} {-# NOINLINE casArrayTicketed# #-} +{-# NOINLINE peekTicket #-} +{-# NOINLINE peekTicketA #-} +{-# NOINLINE seal #-} #else --- {-# INLINE casMutVarTicketed# #-} +{-# INLINE readForCAS# #-} +{-# INLINE readArrayElem# #-} +{-# INLINE casMutVarTicketed# #-} {-# INLINE casArrayTicketed# #-} --- I *think* inlining may be ok here as long as casting happens on the arrow types: +{-# INLINE peekTicket #-} +{-# INLINE peekTicketA #-} +{-# INLINE seal #-} #endif -------------------------------------------------------------------------------- @@ -34,50 +47,74 @@ import GHC.Exts (Int(I#), Any, RealWorld, Int#, State#, MutableArray#, MutVar#, -------------------------------------------------------------------------------- -- | Unsafe, machine-level atomic compare and swap on an element within an Array. -casArrayTicketed# :: MutableArray# RealWorld a -> Int# -> Ticket a -> Ticket a +casArrayTicketed# :: forall a. MutableArray# RealWorld a -> Int# -> Ticket a -> Ticket a -> State# RealWorld -> (# State# RealWorld, Int#, Ticket a #) -- WARNING: cast of a function -- need to verify these are safe or eta expand. -casArrayTicketed# = unsafeCoerce# casArray# +casArrayTicketed# arr i (Ticket old) (Ticket new) s = + case casArray# arr i old new s of + (# s', flag, a #) -> (# s', flag, Ticket a #) --- | When performing compare-and-swaps, the /ticket/ encapsulates proof --- that a thread observed a specific previous value of a mutable --- variable. It is provided in lieu of the "old" value to --- compare-and-swap. +-- | Ordinary processor load instruction (non-atomic, not implying any memory barriers). +readArrayElem# :: forall a . MutableArray# RealWorld a -> Int# + -> State# RealWorld -> (# State# RealWorld, Ticket a #) +readArrayElem# arr i s = + case readArray# arr i s of + (# s', a #) -> (# s', Ticket a #) + +-- | When performing compare-and-swaps, the /ticket/ encapsulates proof that a +-- thread observed a specific previous value of a mutable variable. It is +-- provided in lieu of the "old" value to compare-and-swap. -- -- Design note: `Ticket`s exist to hide objects from the GHC compiler, which --- can normally perform many optimizations that change pointer equality. A Ticket, --- on the other hand, is a first-class object that can be handled by the user, --- but will not have its pointer identity changed by compiler optimizations --- (but will of course, change addresses during garbage collection). -newtype Ticket a = Ticket Any --- If we allow tickets to be a pointer type, then the garbage collector will update --- the pointer when the object moves. +-- can normally perform many optimizations that change pointer equality. A +-- Ticket, on the other hand, is a first-class object that can be handled by +-- the user, but will not have its pointer identity changed by compiler +-- optimizations (but will of course, change addresses during garbage +-- collection). +data Ticket a = Ticket a +-- If we allow tickets to be a pointer type, then the garbage collector will +-- update the pointer when the object moves. + +-- | Wrap up a Haskell value in a ticket. This is not exposed "publicly" for +-- now. Presently the idea is that you must read from the mutable data +-- structure itself to get a ticket. +seal :: a -> Ticket a +seal = Ticket + +-- | Extract a usable Haskell value from a ticket. In many cases, it is better +-- to use 'peekTicketA', to ensure the ticket is unboxed. @peekTicket@ +-- works fine in strict contexts, however. +peekTicket :: Ticket a -> a +-- We use 'lazy' to guarantee that GHC's strictness analysis won't +-- force the ticket contents too early if it sees that the result of +-- `peekTicket` is eventually forced. +peekTicket (Ticket a) = lazy a + +-- | Extract a usable Haskell value from a ticket. +peekTicketA :: Applicative f => Ticket a -> f a +-- We use 'lazy' to guarantee that GHC's strictness analysis won't +-- force the ticket contents too early if it sees that the result of +-- `peekTicket#` is eventually forced. +peekTicketA (Ticket a) = pure (lazy a) instance Show (Ticket a) where show _ = "" -{-# NOINLINE ptrEq #-} -ptrEq :: a -> a -> Bool -ptrEq !x !y = I# (reallyUnsafePtrEquality# x y) == 1 - -instance Eq (Ticket a) where - (==) = ptrEq +-- | Check whether the contents of two tickets are the +-- same pointer. This is used only for testing. +reallyUnsafeTicketEquality :: Ticket a -> Ticket a -> Bool +reallyUnsafeTicketEquality (Ticket x) (Ticket y) = I# (reallyUnsafePtrEquality# x y) == 1 -------------------------------------------------------------------------------- -readForCAS# :: MutVar# RealWorld a -> +readForCAS# :: forall a. MutVar# RealWorld a -> State# RealWorld -> (# State# RealWorld, Ticket a #) --- WARNING: cast of a function -- need to verify these are safe or eta expand: -#ifdef CASTFUN -readForCAS# = unsafeCoerce# readMutVar# -#else -readForCAS# mv rw = - case readMutVar# mv rw of - (# rw', a #) -> (# rw', unsafeCoerce# a #) -#endif +readForCAS# ref s = case readMutVar# ref s of + (# s', a #) -> (# s', Ticket a #) -casMutVarTicketed# :: MutVar# RealWorld a -> Ticket a -> Ticket a -> +casMutVarTicketed# :: forall a. MutVar# RealWorld a -> Ticket a -> Ticket a -> State# RealWorld -> (# State# RealWorld, Int#, Ticket a #) --- WARNING: cast of a function -- need to verify these are safe or eta expand: -casMutVarTicketed# = unsafeCoerce# casMutVar# +casMutVarTicketed# ref (Ticket old) (Ticket new) s = + case casMutVar# ref old new s of + (# s', flag, a #) -> (# s', flag, Ticket a #) diff --git a/atomic-primops/testing/Test.hs b/atomic-primops/testing/Test.hs index cc95cae..b6472e8 100644 --- a/atomic-primops/testing/Test.hs +++ b/atomic-primops/testing/Test.hs @@ -32,6 +32,7 @@ import System.Mem (performGC) ---------------------------------------- import Data.Atomics as A +import Data.Atomics.Internal (reallyUnsafeTicketEquality) import qualified Issue28 @@ -158,7 +159,7 @@ test_random_array_comm threads size iters = do tick0 <- A.readArrayElem arr 0 for_ 1 size $ \ i -> do t2 <- A.readArrayElem arr i - assertEqual "All initial Nothings in the array should be ticket-equal:" tick0 t2 + assertBool "All initial Nothings in the array should be ticket-equal:" (reallyUnsafeTicketEquality tick0 t2) ls <- forkJoin threads $ \_tid -> do localAcc <- newIORef 0