From 3ea4e6459333409c60f66a5745bb472d136da741 Mon Sep 17 00:00:00 2001
From: Alexander Foremny <aforemny@posteo.de>
Date: Wed, 13 Mar 2024 07:02:37 +0100
Subject: chore: add `ConduitT` instances

---
 sh.cabal             |   3 +
 src/Process/Shell.hs | 260 +++++++++++++++++++++++++++++++++++++++++++++++++--
 test/Main.hs         |  53 ++++++++++-
 3 files changed, 306 insertions(+), 10 deletions(-)

diff --git a/sh.cabal b/sh.cabal
index 2c76f73..1174200 100644
--- a/sh.cabal
+++ b/sh.cabal
@@ -17,6 +17,8 @@ library
         aeson,
         base,
         bytestring,
+        conduit,
+        conduit-extra,
         megaparsec,
         mtl,
         template-haskell,
@@ -33,6 +35,7 @@ test-suite sh-test
     build-depends:
         base,
         bytestring,
+        conduit,
         hspec,
         mtl,
         sh,
diff --git a/src/Process/Shell.hs b/src/Process/Shell.hs
index ffebad5..5b079a9 100644
--- a/src/Process/Shell.hs
+++ b/src/Process/Shell.hs
@@ -9,54 +9,279 @@
 module Process.Shell
   ( sh,
     Quotable (..),
+    Outputable (..),
+    Inputable (..),
     ExitCodeException (..),
     DecodeException (..),
   )
 where
 
+import Conduit
 import Control.Exception (Exception, throw)
 import Control.Monad
-import Control.Monad.Reader
 import Data.Aeson
 import Data.ByteString.Char8 qualified as B
 import Data.ByteString.Lazy.Char8 qualified as LB
 import Data.ByteString.Lazy.UTF8 qualified as LB
 import Data.ByteString.UTF8 qualified as B
-import Data.Functor.Identity
+import Data.Conduit.Process.Typed
+import Data.Function
 import Data.Maybe
 import Data.String
 import Data.Text qualified as T
 import Data.Text.Encoding qualified as T
 import Data.Text.Lazy qualified as LT
 import Data.Text.Lazy.Encoding qualified as LT
-import Data.Void
 import Language.Haskell.TH hiding (Type)
 import Language.Haskell.TH.Quote
-import System.Process.Typed
 import Text.Megaparsec
 import Text.Megaparsec.Char
 import Prelude hiding (exp)
 
 data Script = Script (LB.ByteString -> LB.ByteString) String
 
-class Processable m r where
-  sh_ :: m Script -> m r
+class Processable m a where
+  sh_ :: m Script -> m a
 
 instance (MonadIO m) => Processable m () where
   sh_ = ((\(Script _ s) -> liftIO (runProcess_ (fromString s))) =<<)
 
-instance (MonadIO m, Outputable a) => Processable m a where
+instance (MonadIO m, Outputable stdoutAndStderr) => Processable m stdoutAndStderr where
   sh_ = ((\(Script strip s) -> fmap (fromLBS . strip) (liftIO (readProcessInterleaved_ (fromString s)))) =<<)
 
 instance (MonadIO m, Outputable stdout, Outputable stderr) => Processable m (stdout, stderr) where
   sh_ = ((\(Script stripNL s) -> fmap (\(out, err) -> (fromLBS (stripNL out), fromLBS (stripNL err))) (liftIO (readProcess_ (fromString s)))) =<<)
 
+instance (MonadIO m, Outputable stdout) => Processable m (stdout, ()) where
+  sh_ = ((\(Script stripNL s) -> fmap (\out -> (fromLBS (stripNL out), ())) (liftIO (readProcessStdout_ (fromString s)))) =<<)
+
+instance (MonadIO m, Outputable stderr) => Processable m ((), stderr) where
+  sh_ = ((\(Script stripNL s) -> fmap (\err -> ((), fromLBS (stripNL err))) (liftIO (readProcessStderr_ (fromString s)))) =<<)
+
 instance (MonadIO m) => Processable m ExitCode where
   sh_ = ((\(Script _ s) -> liftIO (runProcess (fromString s))) =<<)
 
+instance (MonadIO m, Outputable stdoutAndStderr) => Processable m (ExitCode, stdoutAndStderr) where
+  sh_ = ((\(Script stripNL s) -> fmap (\(exitCode, outErr) -> (exitCode, fromLBS (stripNL outErr))) (liftIO (readProcessInterleaved (fromString s)))) =<<)
+
+instance (MonadIO m, Outputable stdout) => Processable m (ExitCode, stdout, ()) where
+  sh_ = ((\(Script stripNL s) -> fmap (\(exitCode, out) -> (exitCode, fromLBS (stripNL out), ())) (liftIO (readProcessStdout (fromString s)))) =<<)
+
+instance (MonadIO m, Outputable stderr) => Processable m (ExitCode, (), stderr) where
+  sh_ = ((\(Script stripNL s) -> fmap (\(exitCode, err) -> (exitCode, (), fromLBS (stripNL err))) (liftIO (readProcessStderr (fromString s)))) =<<)
+
 instance (MonadIO m, Outputable stdout, Outputable stderr) => Processable m (ExitCode, stdout, stderr) where
   sh_ = ((\(Script stripNL s) -> fmap (\(exitCode, out, err) -> (exitCode, fromLBS (stripNL out), fromLBS (stripNL err))) (liftIO (readProcess (fromString s)))) =<<)
 
+instance
+  (MonadIO m, MonadResource m, Inputable stdin, Outputable stdoutAndStderr, Monoid stdoutAndStderr) =>
+  Processable (ConduitT stdin Void m) stdoutAndStderr
+  where
+  sh_ =
+    ( ( \(Script strip s) -> do
+          let stripC = awaitForever (\i -> maybe (yield (strip i)) (\_ -> yield i) =<< peekC)
+          bracketP
+            ( startProcess
+                ( fromString s
+                    & setStdin createSinkClose
+                    & setStdout createSource
+                    & setStderr createSource
+                )
+            )
+            stopProcess
+            ( \p -> do
+                awaitForever (yield . LB.toStrict . toLBS) .| transPipe liftIO (getStdin p)
+                outErr <-
+                  ( mapOutput fromLBS $
+                      mapOutput LB.fromStrict (transPipe liftIO (getStdout p >> getStderr p))
+                        .| stripC
+                    )
+                    .| foldC
+                checkExitCode p
+                pure outErr
+            )
+      )
+        =<<
+    )
+
+instance
+  (MonadIO m, MonadResource m, Inputable stdin, Outputable stdout, Outputable stderr, Monoid stdout, Monoid stderr) =>
+  Processable (ConduitT stdin Void m) (stdout, stderr)
+  where
+  sh_ =
+    ( ( \(Script strip s) -> do
+          let stripC = awaitForever (\i -> maybe (yield (strip i)) (\_ -> yield i) =<< peekC)
+          bracketP
+            ( startProcess
+                ( fromString s
+                    & setStdin createSinkClose
+                    & setStdout createSource
+                    & setStderr createSource
+                )
+            )
+            stopProcess
+            ( \p -> do
+                awaitForever (yield . LB.toStrict . toLBS) .| transPipe liftIO (getStdin p)
+                outErr <-
+                  (,)
+                    <$> ( ( mapOutput fromLBS $
+                              mapOutput LB.fromStrict (transPipe liftIO (getStdout p))
+                                .| stripC
+                          )
+                            .| foldC
+                        )
+                    <*> ( ( mapOutput fromLBS $
+                              mapOutput LB.fromStrict (transPipe liftIO (getStderr p))
+                                .| stripC
+                          )
+                            .| foldC
+                        )
+                checkExitCode p
+                pure outErr
+            )
+      )
+        =<<
+    )
+
+instance
+  (Monad m, MonadIO m, MonadResource m, Inputable stdin) =>
+  Processable (ConduitT stdin Void m) ()
+  where
+  sh_ =
+    ( ( \(Script _ s) -> do
+          bracketP
+            ( startProcess
+                ( fromString s
+                    & setStdin createSinkClose
+                    & setStdout inherit
+                    & setStderr inherit
+                )
+            )
+            stopProcess
+            ( \p -> do
+                awaitForever (yield . LB.toStrict . toLBS) .| transPipe liftIO (getStdin p)
+                checkExitCode p
+            )
+      )
+        =<<
+    )
+
+instance
+  (MonadIO m, MonadResource m, Inputable stdin, Outputable stdout, Monoid stdout) =>
+  Processable (ConduitT stdin Void m) (stdout, ())
+  where
+  sh_ =
+    ( ( \(Script strip s) -> do
+          let stripC = awaitForever (\i -> maybe (yield (strip i)) (\_ -> yield i) =<< peekC)
+          bracketP
+            ( startProcess
+                ( fromString s
+                    & setStdin createSinkClose
+                    & setStdout createSource
+                    & setStderr inherit
+                )
+            )
+            stopProcess
+            ( \p -> do
+                awaitForever (yield . LB.toStrict . toLBS) .| transPipe liftIO (getStdin p)
+                out <-
+                  ( mapOutput fromLBS $
+                      mapOutput LB.fromStrict (transPipe liftIO (getStdout p))
+                        .| stripC
+                    )
+                    .| foldC
+                checkExitCode p
+                pure (out, ())
+            )
+      )
+        =<<
+    )
+
+instance
+  (MonadIO m, MonadResource m, Inputable stdin, Outputable stderr, Monoid stderr) =>
+  Processable (ConduitT stdin Void m) ((), stderr)
+  where
+  sh_ =
+    ( ( \(Script strip s) -> do
+          let stripC = awaitForever (\i -> maybe (yield (strip i)) (\_ -> yield i) =<< peekC)
+          bracketP
+            ( startProcess
+                ( fromString s
+                    & setStdin createSinkClose
+                    & setStdout inherit
+                    & setStderr createSource
+                )
+            )
+            stopProcess
+            ( \p -> do
+                awaitForever (yield . LB.toStrict . toLBS) .| transPipe liftIO (getStdin p)
+                err <-
+                  ( mapOutput fromLBS $
+                      mapOutput LB.fromStrict (transPipe liftIO (getStderr p))
+                        .| stripC
+                    )
+                    .| foldC
+                checkExitCode p
+                pure ((), err)
+            )
+      )
+        =<<
+    )
+
+instance
+  (MonadIO m, MonadResource m, Inputable stdin, Outputable stdoutAndStderr) =>
+  Processable (ConduitT stdin stdoutAndStderr m) ()
+  where
+  sh_ =
+    ( ( \(Script strip s) -> do
+          let stripC = awaitForever (\i -> maybe (yield (strip i)) (\_ -> yield i) =<< peekC)
+          bracketP
+            ( startProcess
+                ( fromString s
+                    & setStdin createSinkClose
+                    & setStdout createSource
+                    & setStderr createSource
+                )
+            )
+            stopProcess
+            ( \p -> do
+                awaitForever (yield . LB.toStrict . toLBS) .| transPipe liftIO (getStdin p)
+                mapOutput fromLBS $
+                  mapOutput LB.fromStrict (transPipe liftIO (getStdout p >> getStderr p))
+                    .| stripC
+                checkExitCode p
+            )
+      )
+        =<<
+    )
+
+instance
+  (MonadIO m, MonadResource m, Inputable stdin, Outputable stdout, Outputable stderr) =>
+  Processable (ConduitT stdin (Either stderr stdout) m) ()
+  where
+  sh_ =
+    ( ( \(Script strip s) -> do
+          let stripC = awaitForever (\i -> maybe (yield (strip i)) (\_ -> yield i) =<< peekC)
+          bracketP
+            ( startProcess
+                ( fromString s
+                    & setStdin createSinkClose
+                    & setStdout createSource
+                    & setStderr createSource
+                )
+            )
+            stopProcess
+            ( \p -> do
+                awaitForever (yield . LB.toStrict . toLBS) .| transPipe liftIO (getStdin p)
+                transPipe liftIO $ do
+                  mapOutput Right (mapOutput fromLBS (mapOutput LB.fromStrict (getStdout p) .| stripC))
+                  mapOutput Left (mapOutput fromLBS (mapOutput LB.fromStrict (getStderr p) .| stripC))
+                checkExitCode p
+            )
+      )
+        =<<
+    )
+
 class Outputable a where
   fromLBS :: LB.ByteString -> a
 
@@ -75,6 +300,27 @@ instance Outputable T.Text where
 instance Outputable LT.Text where
   fromLBS = LT.decodeUtf8
 
+class Inputable a where
+  toLBS :: a -> LB.ByteString
+
+instance Inputable () where
+  toLBS _ = LB.pack ""
+
+instance Inputable String where
+  toLBS = LB.fromString
+
+instance Inputable B.ByteString where
+  toLBS = LB.fromStrict
+
+instance Inputable LB.ByteString where
+  toLBS = id
+
+instance Inputable T.Text where
+  toLBS = LT.encodeUtf8 . LT.fromStrict
+
+instance Inputable LT.Text where
+  toLBS = LT.encodeUtf8
+
 data DecodeException = DecodeException String
   deriving (Show)
 
diff --git a/test/Main.hs b/test/Main.hs
index 2f20997..f6bfaf6 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -5,6 +5,7 @@
 
 module Main (main) where
 
+import Conduit
 import Data.ByteString.Char8 qualified as B
 import Data.ByteString.Lazy.Char8 qualified as LB
 import Data.Text qualified as T
@@ -14,25 +15,67 @@ import Test.Hspec
 
 main :: IO ()
 main = hspec do
+  describe "input" do
+    it "capture stdin" do
+      (`shouldBe` "stdin") =<< runConduitRes (yield "stdin\n" .| [sh|cat|])
+      (`shouldBe` "stdin") =<< runConduitRes (yield "stdin\n" .| [sh|cat|] .| foldC)
   describe "output" do
     it "capture stdout" do
-      (`shouldBe` "stdout") . fst @String @String =<< [sh|echo stdout|]
+      (`shouldBe` "stdout") . fst @String @() =<< [sh|echo stdout|]
+      (`shouldBe` "stdout") . fst @String @String =<< runConduitRes [sh|echo stdout|]
+      (`shouldBe` "stdout") . snd @String @String =<< runConduitRes ([sh|echo stdout|] .| partitionEithersC)
+    -- TODO conduit
     it "capture stderr" do
-      (`shouldBe` "stderr") . snd @String =<< [sh|>&2 echo stderr|]
+      (`shouldBe` "stderr") . snd @() @String =<< [sh|>&2 echo stderr|]
+      (`shouldBe` "stderr") . snd @() @String =<< runConduitRes [sh|>&2 echo stderr|]
+      (`shouldBe` "stderr") . fst @String @String
+        =<< runConduitRes ([sh|>&2 echo stderr|] .| partitionEithersC)
     it "capture stdout and stderr" do
       (`shouldBe` ("stdout", "stderr"))
         =<< [sh|
         echo stdout
         >&2 echo stderr
         |]
+      (`shouldBe` ("stdout", "stderr"))
+        =<< runConduitRes
+          ( [sh|
+              echo stdout
+              >&2 echo stderr
+              |]
+          )
+      (`shouldBe` ("stderr", "stdout"))
+        =<< runConduitRes
+          ( [sh|
+              echo stdout
+              >&2 echo stderr
+              |]
+              .| partitionEithersC
+          )
     it "capture stdout and stderr interleaved" do
       (`shouldBe` "stdout\nstderr")
         =<< [sh|
         echo stdout
         >&2 echo stderr
         |]
+      (`shouldBe` "stdout\nstderr")
+        =<< runConduitRes
+          ( [sh|
+              echo stdout
+              >&2 echo stderr
+              |]
+          )
+      (`shouldBe` "stdout\nstderr")
+        =<< runConduitRes
+          ( [sh|
+              echo stdout
+              >&2 echo stderr
+              |]
+              .| foldC
+          )
     it "preserve trailing newline" do
-      (`shouldBe` "stdout\n") . fst @String @String =<< [sh|echo stdout \|]
+      (`shouldBe` "stdout\n") =<< [sh|echo stdout \|]
+      (`shouldBe` "stdout\n") =<< runConduitRes [sh|echo stdout \|]
+      (`shouldBe` "stdout\n") =<< runConduitRes ([sh|echo stdout \|] .| foldC)
   describe "arguments" do
     it "passes `Int`" do
       (`shouldBe` "1") =<< let x = 1 :: Int in [sh|echo '#{x}'|]
@@ -64,3 +107,7 @@ main = hspec do
       (`shouldBe` "foobar") =<< [sh|echo 'foobar'|]
     it "parses double cross" do
       (`shouldBe` "0") =<< [sh|echo $#|]
+
+partitionEithersC :: (Monad m, Monoid a, Monoid b) => ConduitT (Either a b) o m (a, b)
+partitionEithersC =
+  foldlC (\(es, rs) x -> either (\e -> (e `mappend` es, rs)) (\r -> (es, r `mappend` rs)) x) (mempty, mempty)
-- 
cgit v1.2.3