From 9abe1e72043062bc35a243e1c2f7027fde42b814 Mon Sep 17 00:00:00 2001
From: Alexander Foremny <aforemny@posteo.de>
Date: Wed, 7 Feb 2024 02:32:13 +0100
Subject: support joins

---
 app/Main.hs | 150 +++++++++++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 128 insertions(+), 22 deletions(-)

(limited to 'app/Main.hs')

diff --git a/app/Main.hs b/app/Main.hs
index d650a33..c0e6d11 100644
--- a/app/Main.hs
+++ b/app/Main.hs
@@ -1,4 +1,5 @@
 {-# LANGUAGE OverloadedStrings #-}
+{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
 
 module Main where
 
@@ -23,15 +24,16 @@ import Text.Printf (printf)
 debug :: Show a => String -> a -> a
 debug s x = trace (printf "%s: %s" s (show x)) x
 
+main :: IO ()
 main = do
   setCurrentDirectory "./data"
-  -- query "SELECT . FROM c"
+  putStrLn "> SELECT . FROM c"
   query' $ Select [Unqualified "."] "c" [] []
-  putStrLn ""
-  -- query "SELECT id FROM c"
-  query' $ Select [Qualified "c" "id"] "c" [] []
-  putStrLn ""
-  -- query "SELECT c.id, j.id, is_j FROM c JOIN j WHERE j.id == c.j_id"
+
+  putStrLn "\n> SELECT id FROM c"
+  query' $ Select [Unqualified "id"] "c" [] []
+
+  putStrLn "\n> SELECT c.id, j.id, is_j FROM c LEFT JOIN j ON j.id == c.j_id"
   query' $
     Select
       [ Qualified "c" "id",
@@ -39,12 +41,43 @@ main = do
         Unqualified "is_j"
       ]
       "c"
-      ["j"]
-      [ Eq (Qualified "j" "id") (Qualified "c" "j_id")
+      [ LeftJoin
+          "j"
+          [ Eq (Qualified "j" "id") (Qualified "c" "j_id")
+          ]
+      ]
+      []
+
+  putStrLn "\n> SELECT c.id, j.id FROM c RIGHT JOIN j ON j.id == c.j_id"
+  query' $
+    Select
+      [ Qualified "c" "id",
+        Qualified "j" "id"
+      ]
+      "c"
+      [ RightJoin
+          "j"
+          [ Eq (Qualified "j" "id") (Qualified "c" "j_id")
+          ]
       ]
+      []
+
+  putStrLn "\n> SELECT c.id, j.id FROM c FULL JOIN j ON j.id == c.j_id"
+  query' $
+    Select
+      [ Qualified "c" "id",
+        Qualified "j" "id"
+      ]
+      "c"
+      [ FullJoin
+          "j"
+          [ Eq (Qualified "j" "id") (Qualified "c" "j_id")
+          ]
+      ]
+      []
 
 data Query
-  = Select [Field] Collection Join Where
+  = Select [Field] Collection [Join FilePath] Where
   deriving (Show)
 
 data Field
@@ -54,7 +87,11 @@ data Field
 
 type Collection = FilePath
 
-type Join = [FilePath]
+data Join a
+  = LeftJoin a Where
+  | RightJoin a Where
+  | FullJoin a Where
+  deriving (Show)
 
 type Where = [Cmp]
 
@@ -74,13 +111,60 @@ instance IsString Query where
 query :: Query -> IO [J.Value]
 query (Select fs c js ws) = do
   c' <- mapM (fmap (Record c) . decodeFile . (c </>)) =<< ls c
-  js' <- mapM (\j -> mapM (fmap (Record j) . decodeFile . (j </>)) =<< ls j) js
+  js' <-
+    mapM
+      ( \j ->
+          case j of
+            LeftJoin c ws ->
+              fmap (\j' -> LeftJoin (map (Record c) j') ws) . mapM (decodeFile . (c </>)) =<< ls c
+            RightJoin c ws ->
+              fmap (\j' -> RightJoin (map (Record c) j') ws) . mapM (decodeFile . (c </>)) =<< ls c
+            FullJoin c ws ->
+              fmap (\j' -> FullJoin (map (Record c) j') ws) . mapM (decodeFile . (c </>)) =<< ls c
+      )
+      js
   pure $ map (select fs) $ where_ ws $ combine c' js'
 
-combine c = combine' (map (: []) c)
+combine :: [Record J.Value] -> [Join [Record J.Value]] -> [[Record J.Value]]
+combine vs js = combine' (map (: []) vs) js
   where
-    combine' cs [] = cs
-    combine' cs (js : jss) = combine' [c ++ [j] | c <- cs, j <- js] jss
+    combine' vss [] = vss
+    combine' vss (LeftJoin js ws : jss) =
+      combine'
+        ( concatMap
+            ( \vs -> case filter (satisfies ws) $ map (\j -> vs ++ [j]) js of
+                [] -> [vs]
+                vs' -> vs'
+            )
+            vss
+        )
+        jss
+    combine' vss (RightJoin js ws : jss) =
+      combine'
+        ( concatMap
+            ( \j -> case filter (satisfies ws) $ map (\vs -> vs ++ [j]) vss of
+                [] -> [[j]]
+                vs' -> vs'
+            )
+            js
+        )
+        jss
+    combine' vss (FullJoin js ws : jss) =
+      combine'
+        ( concatMap
+            ( \vs -> case filter (satisfies ws) $ map (\j -> vs ++ [j]) js of
+                [] -> [vs]
+                vs' -> vs'
+            )
+            vss
+            ++ concatMap
+              ( \j -> case filter (satisfies ws) $ map (\vs -> vs ++ [j]) vss of
+                  [] -> [[j]]
+                  _ -> []
+              )
+              js
+        )
+        jss
 
 ls :: FilePath -> IO [FilePath]
 ls =
@@ -96,7 +180,18 @@ decodeFile =
   fmap (fromMaybe (throw DecodeException)) . J.decodeFileStrict
 
 select :: [Field] -> [Record J.Value] -> J.Value
-select fs vs = join' $ map (select' fs) vs
+select fs vs =
+  mergeUnsafe (join' (map ((\(Record _ v) -> v) . select' fs) vs)) v0
+  where
+    v0 =
+      joinUnsafe $
+        mapMaybe
+          ( \f -> case f of
+              Qualified c k -> Just $ J.Object $ JM.singleton (JK.fromText (T.pack c <> "." <> k)) J.Null
+              Unqualified "." -> Nothing
+              Unqualified k -> Just $ J.Object $ JM.singleton (JK.fromText k) J.Null
+          )
+          fs
 
 select' :: [Field] -> Record J.Value -> Record J.Value
 select' [Unqualified "."] v = v
@@ -114,13 +209,20 @@ matches :: Record T.Text -> Field -> Bool
 matches (Record c k) (Qualified c' k') = c == c' && k == k'
 matches (Record _ k) (Unqualified k') = k == k'
 
-join' :: [Record J.Value] -> J.Value
-join' vs = foldl' merge (J.Object JM.empty) (map (\(Record _ v) -> v) vs)
+join' :: [J.Value] -> J.Value
+join' = foldl' merge (J.Object JM.empty)
+
+joinUnsafe :: [J.Value] -> J.Value
+joinUnsafe = foldl' mergeUnsafe (J.Object JM.empty)
 
 where_ :: Where -> [[Record J.Value]] -> [[Record J.Value]]
-where_ ws vss = filter (\vs -> all (\w -> satisfy w vs) ws) vss
-  where
-    satisfy (Eq f f') vs = unique f vs == unique f' vs
+where_ ws vss = filter (satisfies ws) vss
+
+satisfies :: [Cmp] -> [Record J.Value] -> Bool
+satisfies ws vs = all (\w -> satisfy w vs) ws
+
+satisfy :: Cmp -> [Record J.Value] -> Bool
+satisfy (Eq f f') vs = unique f vs == unique f' vs
 
 data DuplicateField' = DuplicateField'
   deriving (Show)
@@ -144,10 +246,14 @@ data DuplicateField = DuplicateField
 
 instance Exception DuplicateField
 
+mergeUnsafe :: J.Value -> J.Value -> J.Value
+mergeUnsafe (J.Object kvs) (J.Object kvs') =
+  J.Object (JM.union kvs kvs')
+
 merge :: J.Value -> J.Value -> J.Value
-merge (J.Object kvs) (J.Object kvs') =
+merge v@(J.Object kvs) v'@(J.Object kvs') =
   case disjoint kvs kvs' of
-    True -> J.Object (JM.union kvs kvs')
+    True -> mergeUnsafe v v'
     False -> throw DuplicateField
 
 disjoint :: JM.KeyMap v -> JM.KeyMap v -> Bool
-- 
cgit v1.2.3