aboutsummaryrefslogtreecommitdiffstats
path: root/app/History/Plan.hs
diff options
context:
space:
mode:
Diffstat (limited to 'app/History/Plan.hs')
-rw-r--r--app/History/Plan.hs207
1 files changed, 207 insertions, 0 deletions
diff --git a/app/History/Plan.hs b/app/History/Plan.hs
new file mode 100644
index 0000000..3dec391
--- /dev/null
+++ b/app/History/Plan.hs
@@ -0,0 +1,207 @@
+module History.Plan
+ ( Planable (..),
+ formulate,
+ realise,
+ )
+where
+
+import Control.Concurrent (forkIO, killThread)
+import Control.Concurrent.STM (TQueue, TVar, atomically, newTQueueIO, newTVarIO, readTQueue, readTVar, retry, writeTQueue, writeTVar)
+import Control.Monad (forever)
+import Data.Kind (Type)
+import Data.List (intercalate)
+import Data.List.NonEmpty qualified as NE
+import Data.Map qualified as M
+import Data.Proxy (Proxy (Proxy))
+import Data.Time.Clock
+import Parallel (parMapM_)
+import System.IO (hFlush, hPutStr, stderr, stdout)
+import System.Time.Monotonic (Clock, clockGetTime, newClock)
+import Text.Printf (printf)
+
+-- | Left-associative monadic fold of a structure with automatic parallelism, status-reporting and intermittent results (cacelability).
+--
+-- Semantically, `realise (formulate xs)` is equivalent to
+--
+-- ```haskell
+-- \(x:xs) -> foldM propagate (assume <$> protoOf x) =<< mapM protoOf xs
+-- ```
+--
+-- However, this module provides the following features:
+--
+-- - `protoOf` and `propagate` are computed in parallel,
+-- - `propagate` is scheduled to terminate as soon as possible,
+--
+-- This makes it possible to scale the computation across multiple cores, but still interrupt the process for intermittent results.
+--
+-- Additionally, a progress report is presented on stdout.
+class Planable output where
+ type Id output :: Type
+ type Proto output :: Type
+ protoOf :: Proxy output -> Id output -> IO (Proto output)
+ assume :: Proto output -> output
+ propagate :: [Id output] -> output -> Proto output -> output
+
+data Plan output = Plan (NE.NonEmpty (Id output)) [Task output]
+
+data Task output
+ = Compute (Id output)
+ | Propagate (Maybe (Id output)) (Id output)
+
+isCompute, isPropagate :: Task output -> Bool
+isCompute (Compute _) = True
+isCompute _ = False
+isPropagate (Propagate _ _) = True
+isPropagate _ = False
+
+formulate ::
+ Planable output =>
+ Proxy output ->
+ NE.NonEmpty (Id output) ->
+ Plan output
+formulate _ ids@(NE.uncons -> (id, (maybe [] NE.toList -> rest))) =
+ Plan ids $
+ Compute id
+ : Propagate Nothing id
+ : concat
+ ( zipWith
+ ( \id id' ->
+ [ Compute id,
+ Propagate id' id
+ ]
+ )
+ rest
+ (map Just (id : rest))
+ )
+
+data State output = State
+ { tasks :: [Task output],
+ protos :: M.Map (Id output) (Proto output),
+ outputs :: M.Map (Id output) output,
+ elapsed :: Clock
+ }
+
+realise :: (Ord (Id output), Planable output) => Plan output -> IO output
+realise plan@(Plan ids tasks) = do
+ elapsed <- newClock
+ let state0 =
+ State
+ { tasks = tasks,
+ protos = M.empty,
+ outputs = M.empty,
+ elapsed = elapsed
+ }
+ stateT <- newTVarIO state0
+ statusT <- newTQueueIO
+ tid <- forkIO $ forever do
+ state <- atomically (readTQueue statusT)
+ hPutStr stderr . printStatus =<< status state0 state
+ hFlush stdout
+ parMapM_ (step Proxy plan statusT stateT) tasks
+ let id = NE.last ids
+ output <- atomically $ do
+ maybe retry pure . M.lookup id . (.outputs) =<< readTVar stateT
+ killThread tid
+ pure output
+
+step ::
+ (Ord (Id output), Planable output) =>
+ Proxy output ->
+ Plan output ->
+ TQueue (State output) ->
+ TVar (State output) ->
+ Task output ->
+ IO ()
+step proxy _ statusT stateT (Compute id) = do
+ proto <- protoOf proxy id
+ atomically do
+ state <- readTVar stateT
+ let state' = state {protos = M.insert id proto state.protos}
+ writeTVar stateT state'
+ writeTQueue statusT state'
+ pure ()
+step _ (Plan ids _) statusT stateT (Propagate (Just id') id) = do
+ (output, proto) <- atomically do
+ state <- readTVar stateT
+ output <- maybe retry pure (M.lookup id' state.outputs)
+ proto <- maybe retry pure (M.lookup id state.protos)
+ pure (output, proto)
+ let output' = propagate (NE.toList ids) output proto
+ atomically do
+ state <- readTVar stateT
+ let state' = state {outputs = M.insert id output' state.outputs}
+ writeTVar stateT state'
+ writeTQueue statusT state'
+step _ _ statusT stateT (Propagate Nothing id) = do
+ proto <- atomically do
+ state <- readTVar stateT
+ maybe retry pure (M.lookup id state.protos)
+ let output = assume proto
+ atomically do
+ state <- readTVar stateT
+ let state' = state {outputs = M.insert id output state.outputs}
+ writeTVar stateT state'
+ writeTQueue statusT state'
+
+data Status = Status
+ { numTasks :: Progress Int,
+ numProtos :: Progress Int,
+ numOutputs :: Progress Int,
+ elapsed :: DiffTime
+ }
+ deriving (Show, Eq)
+
+data Progress a = Progress
+ { total :: a,
+ completed :: a
+ }
+ deriving (Show, Eq)
+
+status :: State output -> State output -> IO Status
+status state0 state = do
+ let totalTasks = length state0.tasks
+ totalProtos = length (filter isCompute state0.tasks)
+ totalOutputs = length (filter isPropagate state0.tasks)
+ completedTasks = completedProtos + completedOutputs
+ completedProtos = M.size state.protos
+ completedOutputs = M.size state.outputs
+ elapsed <- clockGetTime state.elapsed
+ pure
+ Status
+ { numTasks = Progress totalTasks completedTasks,
+ numProtos = Progress totalProtos completedProtos,
+ numOutputs = Progress totalOutputs completedOutputs,
+ elapsed = elapsed
+ }
+
+printStatus :: Status -> String
+printStatus Status {..} = do
+ let formatProgress completed total = pad total completed <> "/" <> total
+ pad total completed = replicate (length total - length completed) ' ' <> completed
+ eta =
+ formatEta
+ ( (fromIntegral numTasks.total * (realToFrac elapsed))
+ / fromIntegral numTasks.completed
+ )
+ (<> "\r") . intercalate " " $
+ [ formatProgress (show numTasks.completed) (show numTasks.total),
+ formatProgress (show numProtos.completed) (show numProtos.total),
+ formatProgress (show numOutputs.completed) (show numOutputs.total),
+ "ETA " <> eta
+ ]
+
+formatEta :: Double -> String
+formatEta s = do
+ if s < 60
+ then printf "%.1fs" s
+ else do
+ let m = s / 60
+ if m < 60
+ then printf "%.1fm" m
+ else do
+ let h = m / 60
+ if h < 24
+ then printf "%.1hf" h
+ else do
+ let d = h / 24
+ printf "%.1fd" d