-- | Extraction of source-code comments via tree-sitter, plus helpers for
-- adding and removing comment delimiters from text.
module Comment
  ( Comment (..),
    Point (..),
    getComments,
    extractComments,
    CommentStyle (..),
    uncomment,
    comment,
  )
where

import Comment.Language
import Control.Applicative (liftA2)
import Control.Exception (bracket, catch)
import Control.Monad
import Data.Binary (Binary)
import Data.ByteString qualified as B
import Data.ByteString.Lazy qualified as LB
import Data.List (find, sortBy)
import Data.List.NonEmpty qualified as N
import Data.Maybe (fromMaybe)
import Data.Ord (comparing)
import Data.Text qualified as T
import Data.Text.Encoding qualified as T
import Data.Text.Encoding.Error (lenientDecode)
import Exception qualified as E
import Foreign.C.String
import Foreign.Marshal.Alloc (free, malloc)
import Foreign.Marshal.Array (mallocArray, peekArray)
import Foreign.Ptr (nullPtr)
import Foreign.Storable
import GHC.Generics (Generic)
import GHC.Int (Int64)
import Git qualified
import System.FilePath (takeExtension)
import TreeSitter.Node qualified as S
import TreeSitter.Parser qualified as S
import TreeSitter.Tree qualified as S

-- | A comment found in a source file.
data Comment = Comment
  { -- | The bytes between 'startByte' and 'endByte', decoded as UTF-8.
    text :: T.Text,
    -- | The language the file was parsed as.
    language :: Language,
    -- | Start byte offset of the comment node as reported by tree-sitter.
    startByte :: Int64,
    -- | End byte offset of the comment node as reported by tree-sitter.
    endByte :: Int64,
    startPoint :: Point,
    endPoint :: Point,
    -- | The path that was passed to 'extractComments'.
    filePath :: FilePath
  }
  deriving (Eq, Show)

-- | A position in a file. 'extractComments' fills both fields with the
-- tree-sitter row/column plus one.
data Point = Point
  { row :: Int,
    column :: Int
  }
  deriving (Eq, Show, Generic, Binary)

-- | Read @filePath@ at @commitHash@, extract its comments and merge
-- comments that sit on directly consecutive rows.
--
-- The language is chosen from the file extension. If the file cannot be
-- read ('E.CannotReadFile'), it is treated as an empty file.
getComments :: Git.CommitHash -> FilePath -> IO [Comment]
getComments commitHash filePath = do
  contents <-
    Git.readTextFileOfBS commitHash filePath
      `catch` \(_ :: E.CannotReadFile) -> pure ""
  mergeLineComments <$> extractComments filePath language (LB.toStrict contents)
  where
    language = fromExtension (takeExtension filePath)

-- | Sort comments by @(startByte, endByte)@ and merge every run in which
-- each comment starts on the row right after the previous one ends.
--
-- Note that only row adjacency is checked; the kind of comment (line or
-- block) is not taken into account.
mergeLineComments :: [Comment] -> [Comment]
mergeLineComments =
  map mergeGroup
    . chainsBy (\a b -> a.endPoint.row + 1 == b.startPoint.row)
    . sortBy (comparing (liftA2 (,) (.startByte) (.endByte)))

-- | Collapse a chain of comments into one. The texts are joined with
-- 'T.unlines', the end position comes from the last member, and every
-- other field (including the start position) comes from the first member.
mergeGroup :: N.NonEmpty Comment -> Comment
mergeGroup chain@(firstComment N.:| _) =
  firstComment
    { text = T.unlines (map (.text) (N.toList chain)),
      endByte = lastComment.endByte,
      endPoint = lastComment.endPoint
    }
  where
    lastComment = N.last chain

-- | A version of 'Data.List.groupBy' that uses the last added group member
-- for comparison with new candidates for the group. 'Data.List.groupBy'
-- uses the initial member for all subsequent comparisons.
chainsBy :: (a -> a -> Bool) -> [a] -> [N.NonEmpty a]
chainsBy p = reverse . map N.reverse . go []
  where
    -- Chains are accumulated newest-first, and each chain is stored
    -- newest-member-first; both are reversed once at the end.
    go chains [] = chains
    go [] (x : xs) = go [N.singleton x] xs
    go chains@((newest N.:| older) : rest) (x : xs)
      | p newest x = go ((x N.:| newest : older) : rest) xs
      | otherwise = go (N.singleton x : chains) xs

-- | Parse @str'@ with the tree-sitter parser for @language@ and return one
-- 'Comment' per node whose type is listed in @nodeTypes language@.
--
-- The parse tree is released before returning, even if an exception is
-- thrown. If the parser returns a null tree, no comments are reported.
-- Bytes that are not valid UTF-8 are decoded leniently instead of raising
-- an exception when the comment text is later forced.
extractComments :: FilePath -> Language -> B.ByteString -> IO [Comment]
extractComments filePath language str' =
  S.withParser (parser language) $ \tsParser ->
    B.useAsCStringLen str' $ \(str, len) ->
      bracket (S.ts_parser_parse_string tsParser nullPtr str len) deleteTree $ \tree ->
        if tree == nullPtr
          then pure []
          else S.withRootNode tree $ \rootNode ->
            map toComment <$> (commentsFromNodeRec language =<< peek rootNode)
  where
    deleteTree tree = unless (tree == nullPtr) (S.ts_tree_delete tree)

    toComment node =
      Comment
        { text = commentText,
          language = language,
          startByte = start,
          endByte = end,
          startPoint = fromTSPoint (S.nodeStartPoint node),
          endPoint = fromTSPoint (S.nodeEndPoint node),
          filePath = filePath
        }
      where
        start = fromIntegral (S.nodeStartByte node)
        end = fromIntegral (S.nodeEndByte node)
        commentText =
          T.decodeUtf8With lenientDecode
            . B.take (fromIntegral end - fromIntegral start)
            . B.drop (fromIntegral start)
            $ str'

    fromTSPoint point =
      Point
        (fromIntegral (S.pointRow point) + 1)
        (fromIntegral (S.pointColumn point) + 1)

-- | All comment nodes among the given node and its descendants.
commentsFromNodeRec :: Language -> S.Node -> IO [S.Node]
commentsFromNodeRec language node =
  filterM (isCommentNode language) =<< childNodesFromNodeRec node

-- | Whether the node's type name is one of @nodeTypes language@.
isCommentNode :: Language -> S.Node -> IO Bool
isCommentNode language n =
  (`elem` nodeTypes language) <$> peekCString (S.nodeType n)

-- | The given node followed by all of its descendants, in pre-order.
childNodesFromNodeRec :: S.Node -> IO [S.Node]
childNodesFromNodeRec n = do
  children <- childNodesFromNode n
  descendants <- concat <$> mapM childNodesFromNodeRec children
  pure (n : descendants)

-- | The direct children of a node.
--
-- Both temporary buffers are released via 'bracket', so they are freed
-- even if copying or reading the child nodes throws.
childNodesFromNode :: S.Node -> IO [S.Node]
childNodesFromNode n =
  bracket (mallocArray numChildren) free $ \children ->
    bracket malloc free $ \tsNode -> do
      poke tsNode (S.nodeTSNode n)
      S.ts_node_copy_child_nodes tsNode children
      peekArray numChildren children
  where
    numChildren = fromIntegral (S.nodeChildCount n)

-- | How a piece of text is delimited as a comment.
data CommentStyle
  = -- | Every line starts with the given prefix.
    LineStyle T.Text
  | -- | The text is enclosed between a start and an end delimiter.
    BlockStyle T.Text T.Text
  deriving (Eq, Show, Generic, Binary)

-- | Wrap text in the given comment style.
--
-- For 'LineStyle', every line is prefixed with the line prefix and a
-- space, and the result is joined with 'T.unlines'. For 'BlockStyle', the
-- delimiters are attached directly to the text without extra whitespace.
comment :: CommentStyle -> T.Text -> T.Text
comment (LineStyle linePrefix) =
  T.unlines . map ((linePrefix <> " ") <>) . T.lines
comment (BlockStyle blockStart blockEnd) =
  (blockStart <>) . (<> blockEnd)

-- | Remove comment delimiters from a raw comment and report which style
-- was detected.
--
-- Every line is stripped of surrounding whitespace first. If the language
-- defines block delimiters and the text starts with one of its block
-- starts, the text is treated as a block comment; otherwise it is treated
-- as line comments using @lineStart language@.
uncomment :: Language -> T.Text -> (CommentStyle, T.Text)
uncomment language rawText =
  case matchingBlock of
    Just (blockStart, blockEnd) ->
      ( BlockStyle blockStart blockEnd,
        stripBlockComment blockStart blockEnd body
      )
    Nothing ->
      ( LineStyle prefix,
        stripLineComments prefix body
      )
  where
    prefix = lineStart language
    body = T.intercalate "\n" . map T.strip . T.lines $ rawText
    matchingBlock = do
      (blockStarts, blockEnd) <- block language
      blockStart <- find (`T.isPrefixOf` body) blockStarts
      pure (blockStart, blockEnd)

-- | Remove the line-comment prefix, and any whitespace following it, from
-- every line that starts with it. Other lines are left untouched.
stripLineComments :: T.Text -> T.Text -> T.Text
stripLineComments prefix =
  T.intercalate "\n" . map stripLine . T.lines
  where
    stripLine line = maybe line T.stripStart (T.stripPrefix prefix line)

-- | Remove the block delimiters from the start and end of the text, then
-- strip surrounding whitespace.
--
-- Each delimiter is removed independently: a missing end delimiter does
-- not undo the removal of the start delimiter, and vice versa.
stripBlockComment :: T.Text -> T.Text -> T.Text -> T.Text
stripBlockComment blockStart blockEnd =
  T.strip . dropEnd . dropStart
  where
    dropStart t = fromMaybe t (T.stripPrefix blockStart t)
    dropEnd t = fromMaybe t (T.stripSuffix blockEnd t)