{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}

-- |
-- Module      : Test.DejaFu.SCT.Internal.Weighted
-- Copyright   : (c) 2015--2019 Michael Walker
-- License     : MIT
-- Maintainer  : Michael Walker <mike@barrucadu.co.uk>
-- Stability   : experimental
-- Portability : DeriveAnyClass, DeriveGeneric
--
-- Internal types and functions for SCT via weighted random
-- scheduling.  This module is NOT considered to form part of the
-- public interface of this library.
module Test.DejaFu.SCT.Internal.Weighted where

import           Control.DeepSeq      (NFData)
import           Data.List.NonEmpty   (toList)
import           Data.Map.Strict      (Map)
import qualified Data.Map.Strict      as M
import           GHC.Generics         (Generic)
import           System.Random        (RandomGen, randomR)

import           Test.DejaFu.Schedule (Scheduler(..))
import           Test.DejaFu.Types

-------------------------------------------------------------------------------
-- * Weighted random scheduler

-- | State threaded through the weighted random scheduler from one
-- scheduling decision to the next.
data RandSchedState g = RandSchedState
  { schedWeights     :: Map ThreadId Int
    -- ^ Weight of each thread, used to bias which thread is chosen.
  , schedLengthBound :: Maybe LengthBound
    -- ^ Remaining length bound, if one is in force.
  , schedGen         :: g
    -- ^ Source of randomness for drawing weights and choices.
  } deriving (Eq, Generic, NFData, Show)

-- | The weighted random scheduler state before any decision has been
-- made: no thread has a weight yet, and the given optional length
-- bound and random generator are stored as-is.
initialRandSchedState :: Maybe LengthBound -> g -> RandSchedState g
initialRandSchedState bound gen = RandSchedState
  { schedWeights     = M.empty
  , schedLengthBound = bound
  , schedGen         = gen
  }

-- | Weighted random scheduler: assigns to each new thread a weight,
-- and makes a weighted random choice out of the runnable threads at
-- every step.
--
-- The function argument draws a weight for any runnable thread that
-- does not already have one in the scheduler state.  A thread is then
-- picked with probability proportional to its weight.
--
-- If the length bound is @Just 0@, no thread is picked and the state
-- is returned unchanged.  Otherwise the bound (if any) is decremented
-- and the updated weights and generator are stored.
--
-- Only the weights of the threads passed in at this step are kept in
-- the new state, so a thread which is absent at one step is given a
-- freshly drawn weight if it appears again later.
--
-- NOTE(review): this presumably relies on the weight function
-- returning positive weights; if the weights sum to zero or less,
-- @pick@ can return 'Nothing' even though threads are runnable.
randSched :: RandomGen g => (g -> (Int, g)) -> Scheduler (RandSchedState g)
randSched weightf = Scheduler $ \_ threads _ s ->
  let
    -- The runnable threads.
    tids = map fst (toList threads)

    -- The weights of the runnable threads, reusing any weight already
    -- in the state and drawing a new one otherwise.  The fold starts
    -- from an empty map, so its keys are exactly 'tids'.
    (weights', g') = foldr assignWeight (M.empty, schedGen s) tids
    assignWeight tid ~(ws, g0) =
      let (w, g) = maybe (weightf g0) (\w0 -> (w0, g0)) (M.lookup tid (schedWeights s))
      in (M.insert tid w ws, g)

    -- Every key of weights' is a runnable thread, so it needs no
    -- further filtering (an 'elem' check against 'tids' for each key
    -- would cost O(n^2) and always succeed).
    enabled = M.toList weights'

    -- Draw a point in [0, total weight) and select the thread whose
    -- weight interval contains it.
    (choice, g'') = randomR (0, sum (map snd enabled) - 1) g'
    pick :: Int -> [(ThreadId, Int)] -> Maybe ThreadId
    pick idx ((x, f):xs)
      | idx < f = Just x
      | otherwise = pick (idx - f) xs
    pick _ [] = Nothing
  in case schedLengthBound s of
    Just 0 -> (Nothing, s)
    bound ->
      ( pick choice enabled
      , RandSchedState weights' (subtract 1 <$> bound) g''
      )