aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLibravatar Alexander Foremny <aforemny@posteo.de>2023-12-18 02:41:56 +0100
committerLibravatar Alexander Foremny <aforemny@posteo.de>2023-12-18 05:27:40 +0100
commit10c764c022b1e46c84a3b4d3743a58bd1072b5a5 (patch)
tree9e37cf690bbeb8e430ddf4340b08f55c6fa78954
parent0d96613d9aa41f93ebb440bb1aa383456b49f28f (diff)
feat: limit the number of FFI calls for extracting comments
This replaces the tree-sitter bindings with a call to a single C function that traverses the AST. We expect the query API to be slower than manually traversing the tree for this particular use case. This will be addressed in an upcoming commit. @prerequisite-for add-languages-elm-shell-nix
-rw-r--r--anissue.cabal2
-rw-r--r--app/Comment.hs114
-rw-r--r--app/Issue.hs7
-rw-r--r--app/TreeSitter.hs87
-rw-r--r--app/TreeSitter/bridge.c34
-rw-r--r--app/tree_sitter.c85
6 files changed, 170 insertions, 159 deletions
diff --git a/anissue.cabal b/anissue.cabal
index 7043c79..571a577 100644
--- a/anissue.cabal
+++ b/anissue.cabal
@@ -154,4 +154,4 @@ executable anissue
TypeFamilies
ViewPatterns
- c-sources: app/TreeSitter/bridge.c
+ c-sources: app/tree_sitter.c
diff --git a/app/Comment.hs b/app/Comment.hs
index 63f610a..2769c83 100644
--- a/app/Comment.hs
+++ b/app/Comment.hs
@@ -12,7 +12,6 @@ where
import Comment.Language
import Control.Applicative (liftA2)
import Control.Exception (catch)
-import Control.Monad
import Data.Binary (Binary)
import Data.ByteString qualified as B
import Data.ByteString.Lazy qualified as LB
@@ -23,12 +22,10 @@ import Data.Ord (comparing)
import Data.Text qualified as T
import Data.Text.Encoding qualified as T
import Exception qualified as E
-import Foreign.C.String
-import Foreign.Marshal.Alloc (free, malloc)
-import Foreign.Ptr (Ptr, nullPtr)
+import Foreign.Marshal.Alloc (alloca, free)
+import Foreign.Marshal.Array (peekArray)
import Foreign.Storable
import GHC.Generics (Generic)
-import GHC.Int (Int64)
import Git qualified
import System.FilePath (takeExtension)
import TreeSitter qualified as S
@@ -36,8 +33,8 @@ import TreeSitter qualified as S
data Comment = Comment
{ text :: T.Text,
language :: Language,
- startByte :: Int64,
- endByte :: Int64,
+ startByte :: Int,
+ endByte :: Int,
startPoint :: Point,
endPoint :: Point,
filePath :: FilePath
@@ -90,75 +87,44 @@ getComments commitHash filePath =
| otherwise = go (N.singleton x : ass) xs
extractComments :: FilePath -> Language -> B.ByteString -> IO [Comment]
-extractComments filePath language str' = do
- S.withParser (parser language) $ \parser -> do
- B.useAsCStringLen str' $ \(str, len) -> do
- tree <- S.ts_parser_parse_string parser nullPtr str (fromIntegral len)
- node <- malloc
- S.ts_tree_root_node tree node
- x <-
- mapM
- ( \n' -> do
- startByte <- fromIntegral <$> S.ts_node_start_byte n'
- endByte <- fromIntegral <$> S.ts_node_end_byte n'
- let text =
- T.decodeUtf8
- . B.take (fromIntegral endByte - fromIntegral startByte)
- . B.drop (fromIntegral startByte)
- $ str'
-
- startPoint <- do
- point <- malloc
- S.ts_node_start_point n' point
- S.Point {..} <- peek point
- free point
- pure
- Point
- { row = fromIntegral row,
- column = fromIntegral column
- }
- endPoint <- do
- point <- malloc
- S.ts_node_end_point n' point
- S.Point {..} <- peek point
- free point
- pure
- Point
- { row = fromIntegral row,
- column = fromIntegral column
+extractComments filePath language str' =
+ alloca $ \nodesPtrPtr -> do
+ alloca $ \numNodesPtr -> do
+ B.useAsCString str' $ \str ->
+ S.extract_comments
+ (parser language)
+ str
+ nodesPtrPtr
+ numNodesPtr
+ numNodes <- peek numNodesPtr
+ nodesPtr <- peek nodesPtrPtr
+ nodes <- peekArray (fromIntegral numNodes) nodesPtr
+ free nodesPtr
+ pure $
+ map
+ ( \node ->
+ let startByte = fromIntegral node.startByte
+ endByte = fromIntegral node.endByte
+ in Comment
+ { startPoint =
+ Point
+ { row = fromIntegral node.startPoint.row + 1,
+ column = fromIntegral node.startPoint.column + 1
+ },
+ endPoint =
+ Point
+ { row = fromIntegral node.endPoint.row + 1,
+ column = fromIntegral node.endPoint.column + 1
+ },
+ text =
+ T.decodeUtf8
+ . B.take (endByte - startByte)
+ . B.drop startByte
+ $ str',
+ ..
}
-
- pure Comment {..}
)
- =<< (commentsFromNodeRec language node)
- free node
- pure x
-
-commentsFromNodeRec :: Language -> Ptr S.Node -> IO [Ptr S.Node]
-commentsFromNodeRec language =
- (filterM (isCommentNode language) =<<)
- . childNodesFromNodeRec
-
-isCommentNode :: Language -> Ptr S.Node -> IO Bool
-isCommentNode language n =
- fmap (`elem` (nodeTypes language)) . peekCString =<< S.ts_node_type n
-
-childNodesFromNodeRec :: Ptr S.Node -> IO [Ptr S.Node]
-childNodesFromNodeRec n = do
- ns' <- childNodesFromNode n
- ns <- concat <$> mapM childNodesFromNodeRec ns'
- pure $ n : ns
-
-childNodesFromNode :: Ptr S.Node -> IO [Ptr S.Node]
-childNodesFromNode n = do
- numChildren <- fromIntegral <$> S.ts_node_named_child_count n
- mapM
- ( \k -> do
- node <- malloc
- S.ts_node_named_child n k node
- pure node
- )
- [0 .. numChildren - 1]
+ nodes
data CommentStyle
= LineStyle T.Text
diff --git a/app/Issue.hs b/app/Issue.hs
index d58d14d..2743f45 100644
--- a/app/Issue.hs
+++ b/app/Issue.hs
@@ -23,7 +23,6 @@ import Data.Text.Encoding qualified as T
import Data.Text.IO qualified as T
import Data.Time.Clock (UTCTime (utctDay))
import GHC.Generics (Generic)
-import GHC.Int (Int64)
import GHC.Records (HasField (..))
import Git (Author (..), Commit (..))
import Git qualified as Git
@@ -40,8 +39,8 @@ data Issue = Issue
title :: T.Text,
file :: FilePath,
provenance :: Provenance,
- startByte :: Int64,
- endByte :: Int64,
+ startByte :: Int,
+ endByte :: Int,
startPoint :: G.Point,
endPoint :: G.Point,
tags :: [Tag],
@@ -100,7 +99,7 @@ instance HasField "id" Issue T.Text where
getText :: Issue -> IO T.Text
getText (Issue {..}) =
- T.decodeUtf8 . LB.toStrict . LB.take (endByte - startByte) . LB.drop startByte
+ T.decodeUtf8 . LB.toStrict . LB.take (fromIntegral (endByte - startByte)) . LB.drop (fromIntegral startByte)
<$> Git.readTextFileOfBS commitHash file
replaceText :: Issue -> T.Text -> IO ()
diff --git a/app/TreeSitter.hs b/app/TreeSitter.hs
index e911d1b..230fefc 100644
--- a/app/TreeSitter.hs
+++ b/app/TreeSitter.hs
@@ -1,65 +1,60 @@
-module TreeSitter where
-
--- | References: [tree-sitter/api.h](https://github.com/tree-sitter/tree-sitter/blob/master/lib/include/tree_sitter/api.h)
+module TreeSitter
+ ( Language,
+ Node (..),
+ Point (..),
+ extract_comments,
+ tree_sitter_haskell,
+ )
+where
import Foreign.C.String (CString)
import Foreign.C.Types (CInt (..))
import Foreign.Ptr (Ptr)
-import Foreign.Storable (Storable (..), peek)
-
-data Parser
+import Foreign.Storable (Storable (..))
data Language
-data Tree = Tree
-
data Node = Node
+ { startPoint :: Point,
+ endPoint :: Point,
+ startByte :: CInt,
+ endByte :: CInt
+ }
+ deriving (Show, Eq)
instance Storable Node where
- sizeOf _ = 32
+ alignment _ = 8
+ sizeOf _ = 24
+ peek ptr =
+ Node
+ <$> peekByteOff ptr 0
+ <*> peekByteOff ptr 8
+ <*> peekByteOff ptr 16
+ <*> peekByteOff ptr 20
+ poke ptr (Node {..}) = do
+ pokeByteOff ptr 0 startPoint
+ pokeByteOff ptr 8 endPoint
+ pokeByteOff ptr 16 startByte
+ pokeByteOff ptr 20 endByte
data Point = Point
{ row :: CInt,
column :: CInt
- } deriving (Show)
+ }
+ deriving (Show, Eq)
instance Storable Point where
+ alignment _ = 4
sizeOf _ = 8
- alignment _ = 8
- peek p = Point <$> peekByteOff p 0 <*> peekByteOff p 4
-
-withParser :: Ptr Language -> (Ptr Parser -> IO a) -> IO a
-withParser l f = do
- p <- ts_parser_new
- ts_parser_set_language p l
- x <- f p
- ts_parser_delete p
- pure x
-
-foreign import ccall unsafe "ts_node_start_point_p" ts_node_start_point :: Ptr Node -> Ptr Point -> IO ()
-
-foreign import ccall unsafe "ts_node_end_point_p" ts_node_end_point :: Ptr Node -> Ptr Point -> IO ()
-
-foreign import ccall unsafe "ts_node_start_byte_p" ts_node_start_byte :: Ptr Node -> IO CInt
-
-foreign import ccall unsafe "ts_node_end_byte_p" ts_node_end_byte :: Ptr Node -> IO CInt
-
-foreign import ccall unsafe "ts_node_type_p" ts_node_type :: Ptr Node -> IO CString
-
-foreign import ccall unsafe "ts_node_named_child_p" ts_node_named_child :: Ptr Node -> CInt -> Ptr Node -> IO ()
-
-foreign import ccall unsafe "ts_node_named_child_count_p" ts_node_named_child_count :: Ptr Node -> IO CInt
-
-foreign import ccall unsafe "ts_tree_root_node_p" ts_tree_root_node :: Ptr Tree -> Ptr Node -> IO ()
-
-foreign import ccall unsafe "ts_tree_delete" ts_tree_delete :: Ptr Tree -> IO ()
-
-foreign import ccall unsafe "ts_parser_parse_string" ts_parser_parse_string :: Ptr Parser -> Ptr Tree -> CString -> CInt -> IO (Ptr Tree)
-
-foreign import ccall unsafe "ts_parser_new" ts_parser_new :: IO (Ptr Parser)
-
-foreign import ccall unsafe "ts_parser_delete" ts_parser_delete :: Ptr Parser -> IO ()
-
-foreign import ccall unsafe "ts_parser_set_language" ts_parser_set_language :: Ptr Parser -> Ptr Language -> IO ()
+ peek ptr = Point <$> peekByteOff ptr 0 <*> peekByteOff ptr 4
+ poke ptr (Point {..}) = pokeByteOff ptr 0 row >> pokeByteOff ptr 4 column
+
+foreign import ccall unsafe "extract_comments"
+ extract_comments ::
+ Ptr Language ->
+ CString ->
+ Ptr (Ptr Node) ->
+ Ptr CInt ->
+ IO ()
foreign import ccall unsafe "tree_sitter_haskell" tree_sitter_haskell :: Ptr Language
diff --git a/app/TreeSitter/bridge.c b/app/TreeSitter/bridge.c
deleted file mode 100644
index 904c88e..0000000
--- a/app/TreeSitter/bridge.c
+++ /dev/null
@@ -1,34 +0,0 @@
-#include "tree_sitter/api.h"
-#include "string.h"
-
-void ts_tree_root_node_p(TSTree *tree, TSNode *node) {
- (*node) = ts_tree_root_node(tree);
-}
-
-uint32_t ts_node_named_child_count_p(TSNode *node) {
- return ts_node_named_child_count(*node);
-}
-
-uint32_t ts_node_start_byte_p(TSNode *node) {
- return ts_node_start_byte(*node);
-}
-
-uint32_t ts_node_end_byte_p(TSNode *node) {
- return ts_node_end_byte(*node);
-}
-
-uint32_t ts_node_start_point_p(TSNode *node, TSPoint *point) {
- (*point) = ts_node_start_point(*node);
-}
-
-uint32_t ts_node_end_point_p(TSNode *node, TSPoint *point) {
- (*point) = ts_node_end_point(*node);
-}
-
-const char* ts_node_type_p(TSNode *node) {
- return ts_node_type(*node);
-}
-
-void ts_node_named_child_p(TSNode* self, uint32_t child_index, TSNode* node) {
- (*node) = ts_node_named_child(*self, child_index);
-}
diff --git a/app/tree_sitter.c b/app/tree_sitter.c
new file mode 100644
index 0000000..d0f9fa8
--- /dev/null
+++ b/app/tree_sitter.c
@@ -0,0 +1,85 @@
+#include "string.h"
+#include "tree_sitter/api.h"
+
+typedef struct Node {
+ TSPoint start_point;
+ TSPoint end_point;
+ uint32_t start_byte;
+ uint32_t end_byte;
+} Node;
+
+void extract_comments(
+ TSLanguage* language,
+ char* input,
+ Node** out,
+ uint32_t* out_len
+) {
+ TSParser* parser = ts_parser_new();
+ ts_parser_set_language(parser, language);
+ TSTree* tree = ts_parser_parse_string(parser, NULL, input, strlen(input));
+ TSNode root_node = ts_tree_root_node(tree);
+
+ char* pattern = "(comment) @comment";
+ uint32_t error_offset;
+ TSQueryError error_type;
+ TSQuery* query = ts_query_new(language, pattern, strlen(pattern), &error_offset, &error_type);
+ TSQueryCursor* query_cursor = ts_query_cursor_new();
+ ts_query_cursor_exec(query_cursor, query, root_node);
+
+ TSQueryMatch query_match;
+ uint32_t n_max = 1024;
+ *out = malloc(sizeof(Node) * n_max);
+ Node* node = *out;
+ uint32_t n = 0;
+ while (ts_query_cursor_next_match(query_cursor, &query_match)) {
+ if (n >= n_max) {
+ n_max *= 2;
+ *out = realloc(*out, sizeof(Node) * n_max);
+ node = *out + n;
+ }
+ TSNode captured_node = query_match.captures[0].node;
+ node->start_byte = ts_node_start_byte(captured_node);
+ node->end_byte = ts_node_end_byte(captured_node);
+ node->start_point = ts_node_start_point(captured_node);
+ node->end_point = ts_node_end_point(captured_node);
+ node++; n++;
+ }
+ *out_len = n;
+
+ ts_query_cursor_delete(query_cursor);
+ ts_query_delete(query);
+ ts_tree_delete(tree);
+ ts_parser_delete(parser);
+}
+
+void ts_tree_root_node_p(TSTree *tree, TSNode *node) {
+ (*node) = ts_tree_root_node(tree);
+}
+
+uint32_t ts_node_named_child_count_p(TSNode *node) {
+ return ts_node_named_child_count(*node);
+}
+
+uint32_t ts_node_start_byte_p(TSNode *node) {
+ return ts_node_start_byte(*node);
+}
+
+uint32_t ts_node_end_byte_p(TSNode *node) {
+ return ts_node_end_byte(*node);
+}
+
+uint32_t ts_node_start_point_p(TSNode *node, TSPoint *point) {
+ (*point) = ts_node_start_point(*node);
+}
+
+uint32_t ts_node_end_point_p(TSNode *node, TSPoint *point) {
+ (*point) = ts_node_end_point(*node);
+}
+
+const char* ts_node_type_p(TSNode *node) {
+ return ts_node_type(*node);
+}
+
+void ts_node_named_child_p(TSNode* self, uint32_t child_index, TSNode* node) {
+ (*node) = ts_node_named_child(*self, child_index);
+}