diff options
Diffstat (limited to 'app')
-rw-r--r-- | app/Comment.hs | 114 | ||||
-rw-r--r-- | app/Issue.hs | 7 | ||||
-rw-r--r-- | app/TreeSitter.hs | 87 | ||||
-rw-r--r-- | app/TreeSitter/bridge.c | 34 | ||||
-rw-r--r-- | app/tree_sitter.c | 85 |
5 files changed, 169 insertions, 158 deletions
diff --git a/app/Comment.hs b/app/Comment.hs index 63f610a..2769c83 100644 --- a/app/Comment.hs +++ b/app/Comment.hs @@ -12,7 +12,6 @@ where import Comment.Language import Control.Applicative (liftA2) import Control.Exception (catch) -import Control.Monad import Data.Binary (Binary) import Data.ByteString qualified as B import Data.ByteString.Lazy qualified as LB @@ -23,12 +22,10 @@ import Data.Ord (comparing) import Data.Text qualified as T import Data.Text.Encoding qualified as T import Exception qualified as E -import Foreign.C.String -import Foreign.Marshal.Alloc (free, malloc) -import Foreign.Ptr (Ptr, nullPtr) +import Foreign.Marshal.Alloc (alloca, free) +import Foreign.Marshal.Array (peekArray) import Foreign.Storable import GHC.Generics (Generic) -import GHC.Int (Int64) import Git qualified import System.FilePath (takeExtension) import TreeSitter qualified as S @@ -36,8 +33,8 @@ import TreeSitter qualified as S data Comment = Comment { text :: T.Text, language :: Language, - startByte :: Int64, - endByte :: Int64, + startByte :: Int, + endByte :: Int, startPoint :: Point, endPoint :: Point, filePath :: FilePath @@ -90,75 +87,44 @@ getComments commitHash filePath = | otherwise = go (N.singleton x : ass) xs extractComments :: FilePath -> Language -> B.ByteString -> IO [Comment] -extractComments filePath language str' = do - S.withParser (parser language) $ \parser -> do - B.useAsCStringLen str' $ \(str, len) -> do - tree <- S.ts_parser_parse_string parser nullPtr str (fromIntegral len) - node <- malloc - S.ts_tree_root_node tree node - x <- - mapM - ( \n' -> do - startByte <- fromIntegral <$> S.ts_node_start_byte n' - endByte <- fromIntegral <$> S.ts_node_end_byte n' - let text = - T.decodeUtf8 - . B.take (fromIntegral endByte - fromIntegral startByte) - . B.drop (fromIntegral startByte) - $ str' - - startPoint <- do - point <- malloc - S.ts_node_start_point n' point - S.Point {..} <- peek point - free point - pure - Point - { row = fromIntegral row, - column = fromIntegral column - } - endPoint <- do - point <- malloc - S.ts_node_end_point n' point - S.Point {..} <- peek point - free point - pure - Point - { row = fromIntegral row, - column = fromIntegral column +extractComments filePath language str' = + alloca $ \nodesPtrPtr -> do + alloca $ \numNodesPtr -> do + B.useAsCString str' $ \str -> + S.extract_comments + (parser language) + str + nodesPtrPtr + numNodesPtr + numNodes <- peek numNodesPtr + nodesPtr <- peek nodesPtrPtr + nodes <- peekArray (fromIntegral numNodes) nodesPtr + free nodesPtr + pure $ + map + ( \node -> + let startByte = fromIntegral node.startByte + endByte = fromIntegral node.endByte + in Comment + { startPoint = + Point + { row = fromIntegral node.startPoint.row + 1, + column = fromIntegral node.startPoint.column + 1 + }, + endPoint = + Point + { row = fromIntegral node.endPoint.row + 1, + column = fromIntegral node.endPoint.column + 1 + }, + text = + T.decodeUtf8 + . B.take (endByte - startByte) + . B.drop startByte + $ str', + .. } - - pure Comment {..} ) - =<< (commentsFromNodeRec language node) - free node - pure x - -commentsFromNodeRec :: Language -> Ptr S.Node -> IO [Ptr S.Node] -commentsFromNodeRec language = - (filterM (isCommentNode language) =<<) - . childNodesFromNodeRec - -isCommentNode :: Language -> Ptr S.Node -> IO Bool -isCommentNode language n = - fmap (`elem` (nodeTypes language)) . peekCString =<< S.ts_node_type n - -childNodesFromNodeRec :: Ptr S.Node -> IO [Ptr S.Node] -childNodesFromNodeRec n = do - ns' <- childNodesFromNode n - ns <- concat <$> mapM childNodesFromNodeRec ns' - pure $ n : ns - -childNodesFromNode :: Ptr S.Node -> IO [Ptr S.Node] -childNodesFromNode n = do - numChildren <- fromIntegral <$> S.ts_node_named_child_count n - mapM - ( \k -> do - node <- malloc - S.ts_node_named_child n k node - pure node - ) - [0 .. numChildren - 1] + nodes data CommentStyle = LineStyle T.Text diff --git a/app/Issue.hs b/app/Issue.hs index d58d14d..2743f45 100644 --- a/app/Issue.hs +++ b/app/Issue.hs @@ -23,7 +23,6 @@ import Data.Text.Encoding qualified as T import Data.Text.IO qualified as T import Data.Time.Clock (UTCTime (utctDay)) import GHC.Generics (Generic) -import GHC.Int (Int64) import GHC.Records (HasField (..)) import Git (Author (..), Commit (..)) import Git qualified as Git @@ -40,8 +39,8 @@ data Issue = Issue title :: T.Text, file :: FilePath, provenance :: Provenance, - startByte :: Int64, - endByte :: Int64, + startByte :: Int, + endByte :: Int, startPoint :: G.Point, endPoint :: G.Point, tags :: [Tag], @@ -100,7 +99,7 @@ instance HasField "id" Issue T.Text where getText :: Issue -> IO T.Text getText (Issue {..}) = - T.decodeUtf8 . LB.toStrict . LB.take (endByte - startByte) . LB.drop startByte + T.decodeUtf8 . LB.toStrict . LB.take (fromIntegral (endByte - startByte)) . LB.drop (fromIntegral startByte) <$> Git.readTextFileOfBS commitHash file replaceText :: Issue -> T.Text -> IO () diff --git a/app/TreeSitter.hs b/app/TreeSitter.hs index e911d1b..230fefc 100644 --- a/app/TreeSitter.hs +++ b/app/TreeSitter.hs @@ -1,65 +1,60 @@ -module TreeSitter where - --- | References: [tree-sitter/api.h](https://github.com/tree-sitter/tree-sitter/blob/master/lib/include/tree_sitter/api.h) +module TreeSitter + ( Language, + Node (..), + Point (..), + extract_comments, + tree_sitter_haskell, + ) +where import Foreign.C.String (CString) import Foreign.C.Types (CInt (..)) import Foreign.Ptr (Ptr) -import Foreign.Storable (Storable (..), peek) - -data Parser +import Foreign.Storable (Storable (..)) data Language -data Tree = Tree - data Node = Node + { startPoint :: Point, + endPoint :: Point, + startByte :: CInt, + endByte :: CInt + } + deriving (Show, Eq) instance Storable Node where - sizeOf _ = 32 + alignment _ = 8 + sizeOf _ = 24 + peek ptr = + Node + <$> peekByteOff ptr 0 + <*> peekByteOff ptr 8 + <*> peekByteOff ptr 16 + <*> peekByteOff ptr 20 + poke ptr (Node {..}) = do + pokeByteOff ptr 0 startPoint + pokeByteOff ptr 8 endPoint + pokeByteOff ptr 16 startByte + pokeByteOff ptr 20 endByte data Point = Point { row :: CInt, column :: CInt - } deriving (Show) + } + deriving (Show, Eq) instance Storable Point where + alignment _ = 4 sizeOf _ = 8 - alignment _ = 8 - peek p = Point <$> peekByteOff p 0 <*> peekByteOff p 4 - -withParser :: Ptr Language -> (Ptr Parser -> IO a) -> IO a -withParser l f = do - p <- ts_parser_new - ts_parser_set_language p l - x <- f p - ts_parser_delete p - pure x - -foreign import ccall unsafe "ts_node_start_point_p" ts_node_start_point :: Ptr Node -> Ptr Point -> IO () - -foreign import ccall unsafe "ts_node_end_point_p" ts_node_end_point :: Ptr Node -> Ptr Point -> IO () - -foreign import ccall unsafe "ts_node_start_byte_p" ts_node_start_byte :: Ptr Node -> IO CInt - -foreign import ccall unsafe "ts_node_end_byte_p" ts_node_end_byte :: Ptr Node -> IO CInt - -foreign import ccall unsafe "ts_node_type_p" ts_node_type :: Ptr Node -> IO CString - -foreign import ccall unsafe "ts_node_named_child_p" ts_node_named_child :: Ptr Node -> CInt -> Ptr Node -> IO () - -foreign import ccall unsafe "ts_node_named_child_count_p" ts_node_named_child_count :: Ptr Node -> IO CInt - -foreign import ccall unsafe "ts_tree_root_node_p" ts_tree_root_node :: Ptr Tree -> Ptr Node -> IO () - -foreign import ccall unsafe "ts_tree_delete" ts_tree_delete :: Ptr Tree -> IO () - -foreign import ccall unsafe "ts_parser_parse_string" ts_parser_parse_string :: Ptr Parser -> Ptr Tree -> CString -> CInt -> IO (Ptr Tree) - -foreign import ccall unsafe "ts_parser_new" ts_parser_new :: IO (Ptr Parser) - -foreign import ccall unsafe "ts_parser_delete" ts_parser_delete :: Ptr Parser -> IO () - -foreign import ccall unsafe "ts_parser_set_language" ts_parser_set_language :: Ptr Parser -> Ptr Language -> IO () + peek ptr = Point <$> peekByteOff ptr 0 <*> peekByteOff ptr 4 + poke ptr (Point {..}) = pokeByteOff ptr 0 row >> pokeByteOff ptr 4 column + +foreign import ccall unsafe "extract_comments" + extract_comments :: + Ptr Language -> + CString -> + Ptr (Ptr Node) -> + Ptr CInt -> + IO () foreign import ccall unsafe "tree_sitter_haskell" tree_sitter_haskell :: Ptr Language diff --git a/app/TreeSitter/bridge.c b/app/TreeSitter/bridge.c deleted file mode 100644 index 904c88e..0000000 --- a/app/TreeSitter/bridge.c +++ /dev/null @@ -1,34 +0,0 @@ -#include "tree_sitter/api.h" -#include "string.h" - -void ts_tree_root_node_p(TSTree *tree, TSNode *node) { - (*node) = ts_tree_root_node(tree); -} - -uint32_t ts_node_named_child_count_p(TSNode *node) { - return ts_node_named_child_count(*node); -} - -uint32_t ts_node_start_byte_p(TSNode *node) { - return ts_node_start_byte(*node); -} - -uint32_t ts_node_end_byte_p(TSNode *node) { - return ts_node_end_byte(*node); -} - -uint32_t ts_node_start_point_p(TSNode *node, TSPoint *point) { - (*point) = ts_node_start_point(*node); -} - -uint32_t ts_node_end_point_p(TSNode *node, TSPoint *point) { - (*point) = ts_node_end_point(*node); -} - -const char* ts_node_type_p(TSNode *node) { - return ts_node_type(*node); -} - -void ts_node_named_child_p(TSNode* self, uint32_t child_index, TSNode* node) { - (*node) = ts_node_named_child(*self, child_index); -} diff --git a/app/tree_sitter.c b/app/tree_sitter.c new file mode 100644 index 0000000..d0f9fa8 --- /dev/null +++ b/app/tree_sitter.c @@ -0,0 +1,85 @@ +#include "string.h" +#include "tree_sitter/api.h" + +typedef struct Node { + TSPoint start_point; + TSPoint end_point; + uint32_t start_byte; + uint32_t end_byte; +} Node; + +void extract_comments( + TSLanguage* language, + char* input, + Node** out, + uint32_t* out_len +) { + TSParser* parser = ts_parser_new(); + ts_parser_set_language(parser, language); + TSTree* tree = ts_parser_parse_string(parser, NULL, input, strlen(input)); + TSNode root_node = ts_tree_root_node(tree); + + char* pattern = "(comment) @comment"; + uint32_t error_offset; + TSQueryError error_type; + TSQuery* query = ts_query_new(language, pattern, strlen(pattern), &error_offset, &error_type); + TSQueryCursor* query_cursor = ts_query_cursor_new(); + ts_query_cursor_exec(query_cursor, query, root_node); + + TSQueryMatch query_match; + uint32_t n_max = 1024; + *out = malloc(sizeof(Node) * n_max); + Node* node = *out; + uint32_t n = 0; + while (ts_query_cursor_next_match(query_cursor, &query_match)) { + if (n >= n_max) { + n_max *= 2; + *out = realloc(*out, sizeof(Node) * n_max); + node = *out + n; + } + TSNode captured_node = query_match.captures[0].node; + node->start_byte = ts_node_start_byte(captured_node); + node->end_byte = ts_node_end_byte(captured_node); + node->start_point = ts_node_start_point(captured_node); + node->end_point = ts_node_end_point(captured_node); + node++; n++; + } + *out_len = n; + + ts_query_cursor_delete(query_cursor); + ts_query_delete(query); + ts_tree_delete(tree); + ts_parser_delete(parser); +} + +void ts_tree_root_node_p(TSTree *tree, TSNode *node) { + (*node) = ts_tree_root_node(tree); +} + +uint32_t ts_node_named_child_count_p(TSNode *node) { + return ts_node_named_child_count(*node); +} + +uint32_t ts_node_start_byte_p(TSNode *node) { + return ts_node_start_byte(*node); +} + +uint32_t ts_node_end_byte_p(TSNode *node) { + return ts_node_end_byte(*node); +} + +uint32_t ts_node_start_point_p(TSNode *node, TSPoint *point) { + (*point) = ts_node_start_point(*node); +} + +uint32_t ts_node_end_point_p(TSNode *node, TSPoint *point) { + (*point) = ts_node_end_point(*node); +} + +const char* ts_node_type_p(TSNode *node) { + return ts_node_type(*node); +} + +void ts_node_named_child_p(TSNode* self, uint32_t child_index, TSNode* node) { + (*node) = ts_node_named_child(*self, child_index); +} |