Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make tickets reliable #86

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 16 additions & 32 deletions atomic-primops/Data/Atomics.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
module Data.Atomics
(
-- * Types for atomic operations
Ticket, peekTicket, -- CASResult(..),
Ticket, peekTicket, peekTicketA, -- CASResult(..),

-- * Atomic operations on IORefs
readForCAS, casIORef, casIORef2,
Expand Down Expand Up @@ -51,7 +51,7 @@ import Data.Primitive.ByteArray (MutableByteArray(MutableByteArray))
import Data.Atomics.Internal

import Data.IORef
import GHC.IORef hiding (atomicModifyIORef)
import GHC.IORef (IORef (..))
import GHC.STRef
import GHC.Exts hiding ((==#))
import qualified GHC.PrimopWrappers as GPW
Expand All @@ -60,8 +60,6 @@ import GHC.IO (IO(IO))

#ifdef DEBUG_ATOMICS
#warning "Activating DEBUG_ATOMICS... NOINLINE's and more"
{-# NOINLINE seal #-}

{-# NOINLINE casIORef #-}
{-# NOINLINE casArrayElem2 #-}
{-# NOINLINE readArrayElem #-}
Expand Down Expand Up @@ -122,11 +120,8 @@ casArrayElem2 (MutableArray arr#) (I# i#) old new = IO$ \s1# ->

-- | Ordinary processor load instruction (non-atomic, not implying any memory barriers).
readArrayElem :: forall a . MutableArray RealWorld a -> Int -> IO (Ticket a)
-- readArrayElem = unsafeCoerce# readArray#
readArrayElem (MutableArray arr#) (I# i#) = IO $ \ st -> unsafeCoerce# (fn st)
where
fn :: State# RealWorld -> (# State# RealWorld, a #)
fn = readArray# arr# i#
readArrayElem (MutableArray arr#) (I# i#) = IO $ \ st ->
readArrayElem# arr# i# st

-- | Compare and swap on word-sized chunks of a byte-array. For indexing purposes
-- the bytearray is treated as an array of words (`Int`s). Note that UNLIKE
Expand All @@ -147,7 +142,7 @@ casByteArrayInt (MutableByteArray mba#) (I# ix#) (I# old#) (I# new#) =
-- case casByteArrayInt# mba# ix# old# new# s1# of
-- (# s2#, x#, res #) -> (# s2#, (x# ==# 0#, I# res) #)

let (# s2#, res #) = casIntArray# mba# ix# old# new# s1# in
let !(# s2#, res #) = casIntArray# mba# ix# old# new# s1# in
(# s2#, (I# res) #)
-- I don't know if a let will mak any difference here... hopefully not.

Expand All @@ -162,7 +157,7 @@ fetchAddIntArray :: MutableByteArray RealWorld
-> Int -- ^ The value to be added
-> IO Int -- ^ The value *before* the addition
fetchAddIntArray (MutableByteArray mba#) (I# offset#) (I# incr#) = IO $ \ s1# ->
let (# s2#, res #) = fetchAddIntArray# mba# offset# incr# s1# in
let !(# s2#, res #) = fetchAddIntArray# mba# offset# incr# s1# in
(# s2#, (I# res) #)


Expand Down Expand Up @@ -215,7 +210,7 @@ doAtomicRMW :: (MutableByteArray# RealWorld -> Int# -> Int# -> State# RealWorld
doAtomicRMW atomicOp# =
\(MutableByteArray mba#) (I# offset#) (I# val#) ->
IO $ \ s1# ->
let (# s2#, res #) = atomicOp# mba# offset# val# s1# in
let !(# s2#, res #) = atomicOp# mba# offset# val# s1# in
(# s2#, (I# res) #)


Expand All @@ -227,7 +222,7 @@ doAtomicRMW atomicOp# =
-- such as in GCC's `__sync_add_and_fetch`.
fetchAddByteArrayInt :: MutableByteArray RealWorld -> Int -> Int -> IO Int
fetchAddByteArrayInt (MutableByteArray mba#) (I# offset#) (I# incr#) = IO $ \ s1# ->
let (# s2#, res #) = fetchAddIntArray# mba# offset# incr# s1# in
let !(# s2#, res #) = fetchAddIntArray# mba# offset# incr# s1# in
(# s2#, (I# (res +# incr#)) #)


Expand Down Expand Up @@ -301,19 +296,6 @@ casIORef2 (IORef (STRef var)) old new = casMutVar2 var old new

--------------------------------------------------------------------------------

-- | A ticket contains or can get the usable Haskell value.
-- This function does just that.
{-# NOINLINE peekTicket #-}
-- At least this function MUST remain NOINLINE. Issue5 is an example of a bug that
-- ensues otherwise.
peekTicket :: Ticket a -> a
peekTicket = unsafeCoerce#

-- Not exposing this for now. Presently the idea is that you must read from the
-- mutable data structure itself to get a ticket.
seal :: a -> Ticket a
seal = unsafeCoerce#

-- | Like `readForCAS`, but for `MutVar#`.
readMutVarForCAS :: MutVar# RealWorld a -> IO ( Ticket a )
readMutVarForCAS mv = IO$ \ st -> readForCAS# mv st
Expand Down Expand Up @@ -381,8 +363,8 @@ foreign import ccall unsafe "DUP_write_barrier" writeBarrier
-- | A drop-in replacement for `atomicModifyIORef` that
-- optimistically attempts to compute the new value and CAS it into
-- place without introducing new thunks or locking anything. Note
-- that this is more STRICT than its standard counterpart and will only
-- place evaluated (WHNF) values in the IORef.
-- that this is more STRICT than its standard counterpart; the value in the
-- 'IORef` will always be forced to WHNF before the function returns.
--
-- The upside is that sometimes we see a performance benefit.
-- The downside is that this version is speculative -- when it
Expand All @@ -396,9 +378,10 @@ atomicModifyIORefCAS ref fn = do
loop tick effort
where
effort = 30 :: Int -- TODO: Tune this.
loop _ 0 = atomicModifyIORef ref fn -- Fall back to the regular version.
loop _ 0 = atomicModifyIORef ref (\x -> let r = fn x in fst r `seq` r) -- Fall back to the regular version.
loop old tries = do
(new,result) <- evaluate $ fn $ peekTicket old
oldVal <- peekTicketA old
(new,result) <- evaluate $ fn $ oldVal
(b,tick) <- casIORef ref old new
if b
then return result
Expand All @@ -415,9 +398,10 @@ atomicModifyIORefCAS_ ref fn = do
loop tick effort
where
effort = 30 :: Int -- TODO: Tune this.
loop _ 0 = atomicModifyIORef ref (\ x -> (fn x, ()))
loop _ 0 = atomicModifyIORef ref (\ x -> let r = fn x in r `seq` (fn x, ()))
loop old tries = do
new <- evaluate $ fn $ peekTicket old
oldVal <- peekTicketA old
new <- evaluate $ fn $ oldVal
(b,val) <- casIORef ref old new
if b
then return ()
Expand Down
6 changes: 3 additions & 3 deletions atomic-primops/Data/Atomics/Counter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ peekCTicket !x = x
casCounter :: AtomicCounter -> CTicket -> Int -> IO (Bool, CTicket)
-- casCounter (AtomicCounter barr) !old !new =
casCounter (AtomicCounter mba#) (I# old#) newBox@(I# new#) = IO$ \s1# ->
let (# s2#, res# #) = casIntArray# mba# 0# old# new# s1# in
let !(# s2#, res# #) = casIntArray# mba# 0# old# new# s1# in
case res# ==# old# of
False -> (# s2#, (False, I# res# ) #) -- Failure
True -> (# s2#, (True , newBox ) #) -- Success
Expand All @@ -130,12 +130,12 @@ casCounter (AtomicCounter mba#) (I# old#) newBox@(I# new#) = IO$ \s1# ->
-- loop like CAS.
incrCounter :: Int -> AtomicCounter -> IO Int
incrCounter (I# incr#) (AtomicCounter mba#) = IO $ \ s1# ->
let (# s2#, res #) = fetchAddIntArray# mba# 0# incr# s1# in
let !(# s2#, res #) = fetchAddIntArray# mba# 0# incr# s1# in
(# s2#, (I# (res +# incr#)) #)

{-# INLINE incrCounter_ #-}
-- | An alternate version for when you don't care about the old value.
incrCounter_ :: Int -> AtomicCounter -> IO ()
incrCounter_ (I# incr#) (AtomicCounter mba#) = IO $ \ s1# ->
let (# s2#, _ #) = fetchAddIntArray# mba# 0# incr# s1# in
let !(# s2#, _ #) = fetchAddIntArray# mba# 0# incr# s1# in
(# s2#, () #)
113 changes: 75 additions & 38 deletions atomic-primops/Data/Atomics/Internal.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-# LANGUAGE CPP, TypeSynonymInstances, BangPatterns #-}
{-# LANGUAGE ForeignFunctionInterface, GHCForeignImportPrim, MagicHash, UnboxedTuples, UnliftedFFITypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

#define CASTFUN

Expand All @@ -9,75 +10,111 @@ module Data.Atomics.Internal
(
casIntArray#, fetchAddIntArray#,
readForCAS#, casMutVarTicketed#, casArrayTicketed#,
readArrayElem#,
Ticket,
-- * Very unsafe, not to be used
ptrEq
peekTicket,
peekTicketA,
seal,
-- * Very unsafe; for testing only
reallyUnsafeTicketEquality
)
where

import GHC.Exts (Int(I#), Any, RealWorld, Int#, State#, MutableArray#, MutVar#,
unsafeCoerce#, reallyUnsafePtrEquality#,
casArray#, casIntArray#, fetchAddIntArray#, readMutVar#, casMutVar#)
import GHC.Exts (Int(I#), RealWorld, Int#, State#, MutableArray#, MutVar#,
reallyUnsafePtrEquality#, readArray#,
casArray#, casIntArray#, fetchAddIntArray#, readMutVar#, casMutVar#, lazy)

#ifdef DEBUG_ATOMICS
{-# NOINLINE readForCAS# #-}
{-# NOINLINE readArrayElem# #-}
{-# NOINLINE casMutVarTicketed# #-}
{-# NOINLINE casArrayTicketed# #-}
{-# NOINLINE peekTicket #-}
{-# NOINLINE peekTicketA #-}
{-# NOINLINE seal #-}
#else
-- {-# INLINE casMutVarTicketed# #-}
{-# INLINE readForCAS# #-}
{-# INLINE readArrayElem# #-}
{-# INLINE casMutVarTicketed# #-}
{-# INLINE casArrayTicketed# #-}
-- I *think* inlining may be ok here as long as casting happens on the arrow types:
{-# INLINE peekTicket #-}
{-# INLINE peekTicketA #-}
{-# INLINE seal #-}
#endif

--------------------------------------------------------------------------------
-- CAS and friends
--------------------------------------------------------------------------------

-- | Unsafe, machine-level atomic compare and swap on an element within an Array.
casArrayTicketed# :: MutableArray# RealWorld a -> Int# -> Ticket a -> Ticket a
casArrayTicketed# :: forall a. MutableArray# RealWorld a -> Int# -> Ticket a -> Ticket a
-> State# RealWorld -> (# State# RealWorld, Int#, Ticket a #)
-- WARNING: cast of a function -- need to verify these are safe or eta expand.
casArrayTicketed# = unsafeCoerce# casArray#
casArrayTicketed# arr i (Ticket old) (Ticket new) s =
case casArray# arr i old new s of
(# s', flag, a #) -> (# s', flag, Ticket a #)

-- | When performing compare-and-swaps, the /ticket/ encapsulates proof
-- that a thread observed a specific previous value of a mutable
-- variable. It is provided in lieu of the "old" value to
-- compare-and-swap.
-- | Ordinary processor load instruction (non-atomic, not implying any memory barriers).
readArrayElem# :: forall a . MutableArray# RealWorld a -> Int#
-> State# RealWorld -> (# State# RealWorld, Ticket a #)
readArrayElem# arr i s =
case readArray# arr i s of
(# s', a #) -> (# s', Ticket a #)

-- | When performing compare-and-swaps, the /ticket/ encapsulates proof that a
-- thread observed a specific previous value of a mutable variable. It is
-- provided in lieu of the "old" value to compare-and-swap.
--
-- Design note: `Ticket`s exist to hide objects from the GHC compiler, which
-- can normally perform many optimizations that change pointer equality. A Ticket,
-- on the other hand, is a first-class object that can be handled by the user,
-- but will not have its pointer identity changed by compiler optimizations
-- (but will of course, change addresses during garbage collection).
newtype Ticket a = Ticket Any
-- If we allow tickets to be a pointer type, then the garbage collector will update
-- the pointer when the object moves.
-- can normally perform many optimizations that change pointer equality. A
-- Ticket, on the other hand, is a first-class object that can be handled by
-- the user, but will not have its pointer identity changed by compiler
-- optimizations (but will of course, change addresses during garbage
-- collection).
data Ticket a = Ticket a
-- If we allow tickets to be a pointer type, then the garbage collector will
-- update the pointer when the object moves.

-- | Wrap up a Haskell value in a ticket. This is not exposed "publicly" for
-- now. Presently the idea is that you must read from the mutable data
-- structure itself to get a ticket.
seal :: a -> Ticket a
seal = Ticket

-- | Extract a usable Haskell value from a ticket. In many cases, it is better
-- to use 'peekTicketA', to ensure the ticket is unboxed. @peekTicket@
-- works fine in strict contexts, however.
peekTicket :: Ticket a -> a
-- We use 'lazy' to guarantee that GHC's strictness analysis won't
-- force the ticket contents too early if it sees that the result of
-- `peekTicket` is eventually forced.
peekTicket (Ticket a) = lazy a

-- | Extract a usable Haskell value from a ticket.
peekTicketA :: Applicative f => Ticket a -> f a
-- We use 'lazy' to guarantee that GHC's strictness analysis won't
-- force the ticket contents too early if it sees that the result of
-- `peekTicket#` is eventually forced.
peekTicketA (Ticket a) = pure (lazy a)

instance Show (Ticket a) where
show _ = "<CAS_ticket>"

{-# NOINLINE ptrEq #-}
ptrEq :: a -> a -> Bool
ptrEq !x !y = I# (reallyUnsafePtrEquality# x y) == 1

instance Eq (Ticket a) where
(==) = ptrEq
-- | Check whether the contents of two tickets are the
-- same pointer. This is used only for testing.
reallyUnsafeTicketEquality :: Ticket a -> Ticket a -> Bool
reallyUnsafeTicketEquality (Ticket x) (Ticket y) = I# (reallyUnsafePtrEquality# x y) == 1

--------------------------------------------------------------------------------

readForCAS# :: MutVar# RealWorld a ->
readForCAS# :: forall a. MutVar# RealWorld a ->
State# RealWorld -> (# State# RealWorld, Ticket a #)
-- WARNING: cast of a function -- need to verify these are safe or eta expand:
#ifdef CASTFUN
readForCAS# = unsafeCoerce# readMutVar#
#else
readForCAS# mv rw =
case readMutVar# mv rw of
(# rw', a #) -> (# rw', unsafeCoerce# a #)
#endif
readForCAS# ref s = case readMutVar# ref s of
(# s', a #) -> (# s', Ticket a #)


casMutVarTicketed# :: MutVar# RealWorld a -> Ticket a -> Ticket a ->
casMutVarTicketed# :: forall a. MutVar# RealWorld a -> Ticket a -> Ticket a ->
State# RealWorld -> (# State# RealWorld, Int#, Ticket a #)
-- WARNING: cast of a function -- need to verify these are safe or eta expand:
casMutVarTicketed# = unsafeCoerce# casMutVar#
casMutVarTicketed# ref (Ticket old) (Ticket new) s =
case casMutVar# ref old new s of
(# s', flag, a #) -> (# s', flag, Ticket a #)
3 changes: 2 additions & 1 deletion atomic-primops/testing/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import System.Mem (performGC)

----------------------------------------
import Data.Atomics as A
import Data.Atomics.Internal (reallyUnsafeTicketEquality)

import qualified Issue28

Expand Down Expand Up @@ -158,7 +159,7 @@ test_random_array_comm threads size iters = do
tick0 <- A.readArrayElem arr 0
for_ 1 size $ \ i -> do
t2 <- A.readArrayElem arr i
assertEqual "All initial Nothings in the array should be ticket-equal:" tick0 t2
assertBool "All initial Nothings in the array should be ticket-equal:" (reallyUnsafeTicketEquality tick0 t2)

ls <- forkJoin threads $ \_tid -> do
localAcc <- newIORef 0
Expand Down