{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}

-- |
-- Module      : Test.DejaFu.Conc.Internal.Threading
-- Copyright   : (c) 2016--2020 Michael Walker
-- License     : MIT
-- Maintainer  : Michael Walker <mike@barrucadu.co.uk>
-- Stability   : experimental
-- Portability : ExistentialQuantification, FlexibleContexts, RankNTypes
--
-- Operations and types for threads. This module is NOT considered to
-- form part of the public interface of this library.
module Test.DejaFu.Conc.Internal.Threading where

import           Control.Exception                (Exception, MaskingState(..),
                                                   SomeException, fromException)
import           Data.List                        (intersect)
import           Data.Map.Strict                  (Map)
import qualified Data.Map.Strict                  as M
import           Data.Maybe                       (isJust)
import           GHC.Stack                        (HasCallStack)

import           Test.DejaFu.Conc.Internal.Common
import           Test.DejaFu.Internal
import           Test.DejaFu.Types

--------------------------------------------------------------------------------
-- * Threads

-- | Every thread in the system, indexed by its 'ThreadId'.
type Threads n = Map ThreadId (Thread n)

-- | The complete state of a single thread.
data Thread n = Thread
  { _continuation :: Action n
    -- ^ The action this thread will execute next.
  , _blocking :: Maybe BlockedOn
    -- ^ What this thread is blocked on, or 'Nothing' if it is runnable.
  , _handlers :: [Handler n]
    -- ^ Stack of exception handlers, most recently installed first.
  , _masking :: MaskingState
    -- ^ The thread's current exception masking state.
  , _bound :: Maybe (BoundThread n (Action n))
    -- ^ State of the associated bound (worker) thread, if there is one.
  }

-- | Build a fresh thread that will run the given action.
--
-- The new thread is not blocked, has no exception handlers, is
-- 'Unmasked', and has no bound thread.
mkthread :: Action n -> Thread n
mkthread act = Thread
  { _continuation = act
  , _blocking     = Nothing
  , _handlers     = []
  , _masking      = Unmasked
  , _bound        = Nothing
  }

--------------------------------------------------------------------------------
-- * Blocking

-- | What a blocked thread is waiting on.
--
-- Inspected by '~=' (which compares only the kind of block) and by
-- 'wake' (which decides which threads to unblock).
data BlockedOn
  = OnMVarFull MVarId
  | OnMVarEmpty MVarId
  | OnTVar [TVarId]
  | OnMask ThreadId
  deriving (Eq)

-- | Check whether a thread is blocked on the same /kind/ of thing as
-- the given block.
--
-- Only the constructor is compared. The particular 'MVarId',
-- 'TVarId's or 'ThreadId' carried by each side are ignored. A thread
-- that is not blocked never matches.
(~=) :: Thread n -> BlockedOn -> Bool
thread ~= wanted = maybe False (sameKind wanted) (_blocking thread) where
  sameKind (OnMVarFull  _) (OnMVarFull  _) = True
  sameKind (OnMVarEmpty _) (OnMVarEmpty _) = True
  sameKind (OnTVar      _) (OnTVar      _) = True
  sameKind (OnMask      _) (OnMask      _) = True
  sameKind _               _               = False

--------------------------------------------------------------------------------
-- * Exceptions

-- | An installed exception handler.
--
-- The handler accepts only exceptions of its own type @e@. It carries
-- the masking state that was in effect when it was installed.
data Handler n
  = forall e. Exception e => Handler MaskingState (e -> Action n)

-- | Propagate an exception up a thread's handler stack, stopping at
-- the closest handler whose exception type matches.
--
-- Returns 'Nothing' if no handler accepts the exception. Otherwise the
-- thread is updated as follows:
--
-- * it continues with the handler's action;
-- * its masking state is set to the one recorded with the handler;
-- * its handler stack becomes the handlers below the matching one;
-- * it is no longer blocked.
propagate
  :: HasCallStack
  => SomeException
  -> ThreadId
  -> Threads n
  -> Maybe (Threads n)
propagate e tid threads =
  case findHandler (_handlers (elookup tid threads)) of
    Nothing               -> Nothing
    Just (ms, act, below) -> Just (except ms act below tid threads)
  where
    -- Walk the stack from the most recently installed handler, skipping
    -- every handler whose exception type does not fit.
    findHandler [] = Nothing
    findHandler (Handler ms h : below) = case fromException e of
      Just ex -> Just (ms, h ex, below)
      Nothing -> findHandler below

-- | Check if a thread can currently be interrupted by an exception.
--
-- An 'Unmasked' thread always can be interrupted. A
-- 'MaskedInterruptible' thread can be interrupted only while it is
-- blocked. A 'MaskedUninterruptible' thread never can be.
interruptible :: Thread n -> Bool
interruptible thread = case _masking thread of
  Unmasked              -> True
  MaskedInterruptible   -> isJust (_blocking thread)
  MaskedUninterruptible -> False

-- | Push a new exception handler onto a thread's handler stack.
--
-- The handler records the thread's current masking state. 'propagate'
-- restores that state if the handler is later invoked.
catching
  :: (Exception e, HasCallStack)
  => (e -> Action n)
  -> ThreadId
  -> Threads n
  -> Threads n
catching h = eadjust push where
  push thread =
    thread { _handlers = Handler (_masking thread) h : _handlers thread }

-- | Pop the most recently installed exception handler off a thread's
-- handler stack.
--
-- The stack is assumed to be non-empty, because it is shortened with
-- 'etail'.
uncatching :: HasCallStack => ThreadId -> Threads n -> Threads n
uncatching = eadjust pop where
  pop thread = thread { _handlers = etail (_handlers thread) }

-- | Raise an exception in a thread by transferring control to a handler.
--
-- The thread's continuation, masking state and handler stack are all
-- replaced with the given values, and any block it was waiting on is
-- cleared.
except
  :: HasCallStack
  => MaskingState
  -> Action n
  -> [Handler n]
  -> ThreadId
  -> Threads n
  -> Threads n
except ms act hs = eadjust enter where
  enter thread = thread
    { _continuation = act
    , _masking      = ms
    , _handlers     = hs
    , _blocking     = Nothing
    }

-- | Overwrite the exception masking state of a thread.
mask :: HasCallStack => MaskingState -> ThreadId -> Threads n -> Threads n
mask newState = eadjust setMasking where
  setMasking thread = thread { _masking = newState }

--------------------------------------------------------------------------------
-- * Manipulating threads

-- | Replace the next 'Action' a thread will execute.
goto :: HasCallStack => Action n -> ThreadId -> Threads n -> Threads n
goto next = eadjust setNext where
  setNext thread = thread { _continuation = next }

-- | Start a new thread with the given ID. The new thread inherits the
-- masking state of the parent thread.
--
-- The new ID must not already be in use.
launch
  :: HasCallStack
  => ThreadId
  -> ThreadId
  -> ((forall b. ModelConc n b -> ModelConc n b) -> Action n)
  -> Threads n
  -> Threads n
launch parent tid a threads =
  let inherited = _masking (elookup parent threads)
  in launch' inherited tid a threads

-- | Start a new thread with the given ID and masking state.
--
-- The new ID must not already be in use.
--
-- The thread's action is built by applying the supplied function to an
-- "unmask" combinator. That combinator:
--
-- * resets the masking state to 'Unmasked';
-- * runs its argument;
-- * resets the masking state back to @ms@;
-- * returns the argument's result.
--
-- The new thread is not blocked, has no handlers, and has no bound
-- thread.
launch'
  :: HasCallStack
  => MaskingState
  -> ThreadId
  -> ((forall b. ModelConc n b -> ModelConc n b) -> Action n)
  -> Threads n
  -> Threads n
launch' ms tid a = einsert tid (Thread (a unmask) Nothing [] ms Nothing) where
  unmask body = do
    resetMask True Unmasked
    result <- body
    resetMask False ms
    pure result

  resetMask typ m = ModelConc (\k -> AResetMask typ True m (k ()))

-- | Mark a thread as blocked on the given thing, replacing any previous
-- block.
block :: HasCallStack => BlockedOn -> ThreadId -> Threads n -> Threads n
block reason = eadjust setBlocked where
  setBlocked thread = thread { _blocking = Just reason }

-- | Unblock every thread that is waiting on the given block.
--
-- Returns the updated thread map together with the IDs of the threads
-- that were woken.
--
-- Matching rules:
--
-- * For an 'OnTVar' block, a thread is woken if it waits on at least
--   one of the given 'TVar's.
-- * For any other block, the thread must be blocked on exactly the
--   same thing.
wake :: BlockedOn -> Threads n -> (Threads n, [ThreadId])
wake blockedOn threads = (unblock <$> threads, M.keys (M.filter isBlocked threads)) where
  unblock thread
    | isBlocked thread = thread { _blocking = Nothing }
    | otherwise = thread

  isBlocked thread = case (_blocking thread, blockedOn) of
    -- 'TVar' blocks match on any overlap of the watched variables,
    -- not on the exact list.
    (Just (OnTVar tvids), OnTVar blockedOn') ->
      not (null (tvids `intersect` blockedOn'))
    (theblock, _) -> theblock == Just blockedOn

-------------------------------------------------------------------------------
-- ** Bound threads

-- | Turn a thread into a bound thread.
--
-- The given action is run to create the bound-thread state. That state
-- is then stored in the thread's '_bound' field.
makeBound
  :: (MonadDejaFu n, HasCallStack)
  => n (BoundThread n (Action n))
  -> ThreadId
  -> Threads n
  -> n (Threads n)
makeBound fbt tid threads = attach <$> fbt where
  attach bt = eadjust (\t -> t { _bound = Just bt }) tid threads

-- | Kill a thread and remove it from the thread map.
--
-- If the thread is bound, its worker thread is cleaned up first.
kill :: (MonadDejaFu n, HasCallStack) => ThreadId -> Threads n -> n (Threads n)
kill tid threads = do
  case _bound (elookup tid threads) of
    Just bt -> killBoundThread bt
    Nothing -> pure ()
  pure (M.delete tid threads)

-- | Run a lifted action on behalf of a thread.
--
-- If the thread exists and is bound, the action runs in its worker
-- thread. Otherwise it runs directly.
runLiftedAct :: MonadDejaFu n => ThreadId -> Threads n -> n (Action n) -> n (Action n)
runLiftedAct tid threads ma =
  maybe ma (`runInBoundThread` ma) (M.lookup tid threads >>= _bound)