{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE UndecidableInstances #-}
module Control.Concurrent.Classy.RWLock
(
RWLock
, newRWLock
, newAcquiredRead
, newAcquiredWrite
, acquireRead
, releaseRead
, withRead
, waitRead
, tryAcquireRead
, tryWithRead
, acquireWrite
, releaseWrite
, withWrite
, waitWrite
, tryAcquireWrite
, tryWithWrite
) where
import Control.Applicative (pure, (<*>))
import Control.Exception (ErrorCall(..))
import Control.Monad (Monad, (>>))
import Data.Bool (Bool(False, True))
import Data.Eq (Eq, (==))
import Data.Function (on, ($))
import Data.Functor ((<$>))
import Data.Int (Int)
import Data.List ((++))
import Data.Maybe (Maybe(Just, Nothing))
import Data.Ord (Ord)
import Data.Typeable (Typeable)
import Prelude (String, error, pred, succ)
import Text.Read (Read)
import Text.Show (Show)

import Control.Monad.Catch (MonadThrow(throwM), bracket_, mask, mask_,
                            onException)

import Control.Concurrent.Classy.Lock (Lock)
import qualified Control.Concurrent.Classy.Lock as Lock
import qualified Control.Concurrent.Classy.MVar as MVar
import Control.Monad.Conc.Class (MonadConc(MVar))
-- | A multiple-reader, single-writer lock.
--
-- At any time the lock is free, held by one or more readers, or held by a
-- single writer (see 'State').  Every operation takes '_state' while it
-- inspects or changes the state and puts it back before it blocks, so
-- '_state' also serialises access to the two 'Lock's below.
data RWLock m = RWLock
  { _state     :: MVar m State
    -- ^ Current ownership state of the lock.
  , _readLock  :: Lock m
    -- ^ Held exactly while the state is @'Read' n@; writers wait on it.
  , _writeLock :: Lock m
    -- ^ Held exactly while the state is 'Write'; readers and writers wait
    -- on it.
  } deriving (Typeable)
-- | Two 'RWLock's are equal when they share the same state 'MVar'.  Every
-- constructor in this module allocates a fresh 'MVar', so this compares lock
-- identity rather than lock contents.
instance (Eq (MVar m State)) => Eq (RWLock m) where
  (==) = (==) `on` _state
-- | Ownership state of an 'RWLock'.
data State
  = Free
    -- ^ Nobody holds the lock.
  | Read !Int
    -- ^ Held by this many readers.  The count is never 0: when the last
    -- reader releases, the state becomes 'Free'.
  | Write
    -- ^ Held by a single writer.
  deriving (Eq, Ord, Show, Read)
-- | Create a new 'RWLock' in the 'Free' state, with both internal locks
-- released.
newRWLock :: (MonadConc m) => m (RWLock m)
newRWLock = do
  state <- MVar.newMVar Free
  rlock <- Lock.newLock
  RWLock state rlock <$> Lock.newLock
-- | Create a new 'RWLock' that is already held by one reader (state
-- @'Read' 1@, read lock acquired).  That hold must eventually be given back
-- with 'releaseRead'.
newAcquiredRead :: (MonadConc m) => m (RWLock m)
newAcquiredRead = do
  state <- MVar.newMVar (Read 1)
  rlock <- Lock.newAcquired
  RWLock state rlock <$> Lock.newLock
-- | Create a new 'RWLock' that is already held by a writer (state 'Write',
-- write lock acquired).  That hold must eventually be given back with
-- 'releaseWrite'.
newAcquiredWrite :: (MonadConc m) => m (RWLock m)
newAcquiredWrite = do
  state <- MVar.newMVar Write
  rlock <- Lock.newLock
  RWLock state rlock <$> Lock.newAcquired
-- | Acquire the lock for reading.
--
-- * 'Free': take the read lock and record one reader.
-- * @'Read' n@: join the existing readers (the count goes up by one).
-- * 'Write': put the state back, wait on the write lock, then retry.
--
-- Runs under 'mask_'.  Whenever it waits on the write lock, it has already
-- restored the state 'MVar'.
acquireRead :: (MonadConc m) => RWLock m -> m ()
acquireRead RWLock {_state, _readLock, _writeLock} = mask_ go
  where
    go = do
      st <- MVar.takeMVar _state
      case st of
        Free -> do
          Lock.acquire _readLock
          MVar.putMVar _state $ Read 1
        Read n -> MVar.putMVar _state $ Read (succ n)
        Write -> do
          MVar.putMVar _state st
          Lock.wait _writeLock
          -- The state may have changed while we waited; start over.
          go
-- | Try to acquire the lock for reading without waiting for a writer.
--
-- Returns 'True' (and counts this reader) when the lock is 'Free' or already
-- held by readers.  Returns 'False', leaving the state untouched, when a
-- writer holds it.
tryAcquireRead :: (MonadConc m) => RWLock m -> m Bool
tryAcquireRead RWLock {_state, _readLock} = mask_ $ do
  st <- MVar.takeMVar _state
  case st of
    Free -> do
      Lock.acquire _readLock
      MVar.putMVar _state $ Read 1
      pure True
    Read n -> do
      MVar.putMVar _state $ Read (succ n)
      pure True
    Write -> do
      MVar.putMVar _state st
      pure False
-- | Give back one read hold.
--
-- When the last reader releases, the read lock is released and the state
-- returns to 'Free'.  If the lock is not held by any reader, the state is
-- restored unchanged and an \"already released\" error is raised.
releaseRead :: (MonadConc m) => RWLock m -> m ()
releaseRead RWLock {_state, _readLock} = mask_ $ do
  st <- MVar.takeMVar _state
  case st of
    Read 1 -> do
      Lock.release _readLock
      MVar.putMVar _state Free
    Read n -> MVar.putMVar _state $ Read (pred n)
    _ -> do
      -- Put the state back first so the lock stays usable after the error.
      MVar.putMVar _state st
      throw "releaseRead" "already released"
-- | Run an action while holding the lock for reading.  The read hold is
-- released afterwards, even if the action throws.
withRead :: (MonadConc m) => RWLock m -> m a -> m a
-- Uses the function Applicative:
-- @withRead l = bracket_ (acquireRead l) (releaseRead l)@.
withRead = bracket_ <$> acquireRead <*> releaseRead
-- | Like 'withRead', but does not wait when a writer holds the lock.
--
-- If the read hold is obtained, runs the action and releases the hold
-- (also on exception), returning @'Just' result@.  Otherwise returns
-- 'Nothing' without running the action.
tryWithRead :: (MonadConc m) => RWLock m -> m a -> m (Maybe a)
tryWithRead l a = mask $ \restore -> do
  acquired <- tryAcquireRead l
  if acquired
    then do
      -- Only the user's action runs with the caller's masking state restored.
      r <- restore a `onException` releaseRead l
      releaseRead l
      pure $ Just r
    else pure Nothing
-- | Wait until the lock could be acquired for reading, without keeping it.
-- The hold is acquired and immediately released, with asynchronous
-- exceptions masked in between.
waitRead :: (MonadConc m) => RWLock m -> m ()
waitRead l = mask_ (acquireRead l >> releaseRead l)
-- | Acquire the lock for writing.
--
-- * 'Free': take the write lock and move to 'Write'.
-- * @'Read' _@: put the state back, wait on the read lock, then retry.
-- * 'Write': put the state back, wait on the write lock, then retry.
--
-- New readers are admitted while the lock is in @'Read' n@ (see
-- 'acquireRead'), so a steady stream of readers can keep a writer waiting.
-- Runs under 'mask_'.  Whenever it waits, it has already restored the state
-- 'MVar'.
acquireWrite :: (MonadConc m) => RWLock m -> m ()
acquireWrite RWLock {_state, _readLock, _writeLock} = mask_ go
  where
    go = do
      st <- MVar.takeMVar _state
      case st of
        Free -> do
          Lock.acquire _writeLock
          MVar.putMVar _state Write
        Read _ -> do
          MVar.putMVar _state st
          Lock.wait _readLock
          go
        Write -> do
          MVar.putMVar _state st
          Lock.wait _writeLock
          go
-- | Try to acquire the lock for writing without waiting.
--
-- Returns 'True' when the lock was 'Free' (it is now held for writing).
-- Returns 'False', leaving the state untouched, when any reader or writer
-- holds it.
tryAcquireWrite :: (MonadConc m) => RWLock m -> m Bool
tryAcquireWrite RWLock {_state, _writeLock} = mask_ $ do
  st <- MVar.takeMVar _state
  case st of
    Free -> do
      Lock.acquire _writeLock
      MVar.putMVar _state Write
      pure True
    _ -> do
      MVar.putMVar _state st
      pure False
-- | Release the write hold, returning the lock to 'Free'.
--
-- If the lock is not held for writing, the state is restored unchanged and
-- an \"already released\" error is raised.
releaseWrite :: (MonadConc m) => RWLock m -> m ()
releaseWrite RWLock {_state, _writeLock} = mask_ $ do
  st <- MVar.takeMVar _state
  case st of
    Write -> do
      Lock.release _writeLock
      MVar.putMVar _state Free
    _ -> do
      -- Put the state back first so the lock stays usable after the error.
      MVar.putMVar _state st
      throw "releaseWrite" "already released"
-- | Run an action while holding the lock for writing.  The write hold is
-- released afterwards, even if the action throws.
withWrite :: (MonadConc m) => RWLock m -> m a -> m a
-- Uses the function Applicative:
-- @withWrite l = bracket_ (acquireWrite l) (releaseWrite l)@.
withWrite = bracket_ <$> acquireWrite <*> releaseWrite
-- | Like 'withWrite', but does not wait when the lock is held by readers or
-- a writer.
--
-- If the write hold is obtained, runs the action and releases the hold
-- (also on exception), returning @'Just' result@.  Otherwise returns
-- 'Nothing' without running the action.
tryWithWrite :: (MonadConc m) => RWLock m -> m a -> m (Maybe a)
tryWithWrite l a = mask $ \restore -> do
  acquired <- tryAcquireWrite l
  if acquired
    then do
      -- Only the user's action runs with the caller's masking state restored.
      r <- restore a `onException` releaseWrite l
      releaseWrite l
      pure $ Just r
    else pure Nothing
-- | Wait until the lock could be acquired for writing, without keeping it.
-- The hold is acquired and immediately released, with asynchronous
-- exceptions masked in between.
waitWrite :: (MonadConc m) => RWLock m -> m ()
waitWrite l = mask_ (acquireWrite l >> releaseWrite l)
-- | Fail with an 'ErrorCall' whose message names this module and the
-- function that failed: @Control.Concurrent.Classy.RWLock.<func>: <msg>@.
--
-- Uses 'throwM' rather than the pure 'error' for two reasons.  The exception
-- is raised as an effect at exactly this point in the action.  It is also
-- visible to the monad's own handlers ('catch', 'onException', 'bracket_'),
-- instead of escaping as a bottom value whenever it happens to be forced.
-- 'MonadThrow' is a superclass of 'MonadConc', so every caller here
-- satisfies it.
throw :: (MonadThrow m) => String -> String -> m a
throw func msg =
  throwM (ErrorCall ("Control.Concurrent.Classy.RWLock." ++ func ++ ": " ++ msg))