aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--anissue.cabal2
-rw-r--r--app/Comment.hs114
-rw-r--r--app/Issue.hs7
-rw-r--r--app/TreeSitter.hs87
-rw-r--r--app/TreeSitter/bridge.c34
-rw-r--r--app/tree_sitter.c85
6 files changed, 170 insertions, 159 deletions
diff --git a/anissue.cabal b/anissue.cabal
index 7043c79..571a577 100644
--- a/anissue.cabal
+++ b/anissue.cabal
@@ -154,4 +154,4 @@ executable anissue
TypeFamilies
ViewPatterns
- c-sources: app/TreeSitter/bridge.c
+ c-sources: app/tree_sitter.c
diff --git a/app/Comment.hs b/app/Comment.hs
index 63f610a..2769c83 100644
--- a/app/Comment.hs
+++ b/app/Comment.hs
@@ -12,7 +12,6 @@ where
import Comment.Language
import Control.Applicative (liftA2)
import Control.Exception (catch)
-import Control.Monad
import Data.Binary (Binary)
import Data.ByteString qualified as B
import Data.ByteString.Lazy qualified as LB
@@ -23,12 +22,10 @@ import Data.Ord (comparing)
import Data.Text qualified as T
import Data.Text.Encoding qualified as T
import Exception qualified as E
-import Foreign.C.String
-import Foreign.Marshal.Alloc (free, malloc)
-import Foreign.Ptr (Ptr, nullPtr)
+import Foreign.Marshal.Alloc (alloca, free)
+import Foreign.Marshal.Array (peekArray)
import Foreign.Storable
import GHC.Generics (Generic)
-import GHC.Int (Int64)
import Git qualified
import System.FilePath (takeExtension)
import TreeSitter qualified as S
@@ -36,8 +33,8 @@ import TreeSitter qualified as S
data Comment = Comment
{ text :: T.Text,
language :: Language,
- startByte :: Int64,
- endByte :: Int64,
+ startByte :: Int,
+ endByte :: Int,
startPoint :: Point,
endPoint :: Point,
filePath :: FilePath
@@ -90,75 +87,44 @@ getComments commitHash filePath =
| otherwise = go (N.singleton x : ass) xs
extractComments :: FilePath -> Language -> B.ByteString -> IO [Comment]
-extractComments filePath language str' = do
- S.withParser (parser language) $ \parser -> do
- B.useAsCStringLen str' $ \(str, len) -> do
- tree <- S.ts_parser_parse_string parser nullPtr str (fromIntegral len)
- node <- malloc
- S.ts_tree_root_node tree node
- x <-
- mapM
- ( \n' -> do
- startByte <- fromIntegral <$> S.ts_node_start_byte n'
- endByte <- fromIntegral <$> S.ts_node_end_byte n'
- let text =
- T.decodeUtf8
- . B.take (fromIntegral endByte - fromIntegral startByte)
- . B.drop (fromIntegral startByte)
- $ str'
-
- startPoint <- do
- point <- malloc
- S.ts_node_start_point n' point
- S.Point {..} <- peek point
- free point
- pure
- Point
- { row = fromIntegral row,
- column = fromIntegral column
- }
- endPoint <- do
- point <- malloc
- S.ts_node_end_point n' point
- S.Point {..} <- peek point
- free point
- pure
- Point
- { row = fromIntegral row,
- column = fromIntegral column
+extractComments filePath language str' =
+ alloca $ \nodesPtrPtr -> do
+ alloca $ \numNodesPtr -> do
+ B.useAsCString str' $ \str ->
+ S.extract_comments
+ (parser language)
+ str
+ nodesPtrPtr
+ numNodesPtr
+ numNodes <- peek numNodesPtr
+ nodesPtr <- peek nodesPtrPtr
+ nodes <- peekArray (fromIntegral numNodes) nodesPtr
+ free nodesPtr
+ pure $
+ map
+ ( \node ->
+ let startByte = fromIntegral node.startByte
+ endByte = fromIntegral node.endByte
+ in Comment
+ { startPoint =
+ Point
+ { row = fromIntegral node.startPoint.row + 1,
+ column = fromIntegral node.startPoint.column + 1
+ },
+ endPoint =
+ Point
+ { row = fromIntegral node.endPoint.row + 1,
+ column = fromIntegral node.endPoint.column + 1
+ },
+ text =
+ T.decodeUtf8
+ . B.take (endByte - startByte)
+ . B.drop startByte
+ $ str',
+ ..
}
-
- pure Comment {..}
)
- =<< (commentsFromNodeRec language node)
- free node
- pure x
-
-commentsFromNodeRec :: Language -> Ptr S.Node -> IO [Ptr S.Node]
-commentsFromNodeRec language =
- (filterM (isCommentNode language) =<<)
- . childNodesFromNodeRec
-
-isCommentNode :: Language -> Ptr S.Node -> IO Bool
-isCommentNode language n =
- fmap (`elem` (nodeTypes language)) . peekCString =<< S.ts_node_type n
-
-childNodesFromNodeRec :: Ptr S.Node -> IO [Ptr S.Node]
-childNodesFromNodeRec n = do
- ns' <- childNodesFromNode n
- ns <- concat <$> mapM childNodesFromNodeRec ns'
- pure $ n : ns
-
-childNodesFromNode :: Ptr S.Node -> IO [Ptr S.Node]
-childNodesFromNode n = do
- numChildren <- fromIntegral <$> S.ts_node_named_child_count n
- mapM
- ( \k -> do
- node <- malloc
- S.ts_node_named_child n k node
- pure node
- )
- [0 .. numChildren - 1]
+ nodes
data CommentStyle
= LineStyle T.Text
diff --git a/app/Issue.hs b/app/Issue.hs
index d58d14d..2743f45 100644
--- a/app/Issue.hs
+++ b/app/Issue.hs
@@ -23,7 +23,6 @@ import Data.Text.Encoding qualified as T
import Data.Text.IO qualified as T
import Data.Time.Clock (UTCTime (utctDay))
import GHC.Generics (Generic)
-import GHC.Int (Int64)
import GHC.Records (HasField (..))
import Git (Author (..), Commit (..))
import Git qualified as Git
@@ -40,8 +39,8 @@ data Issue = Issue
title :: T.Text,
file :: FilePath,
provenance :: Provenance,
- startByte :: Int64,
- endByte :: Int64,
+ startByte :: Int,
+ endByte :: Int,
startPoint :: G.Point,
endPoint :: G.Point,
tags :: [Tag],
@@ -100,7 +99,7 @@ instance HasField "id" Issue T.Text where
getText :: Issue -> IO T.Text
getText (Issue {..}) =
- T.decodeUtf8 . LB.toStrict . LB.take (endByte - startByte) . LB.drop startByte
+ T.decodeUtf8 . LB.toStrict . LB.take (fromIntegral (endByte - startByte)) . LB.drop (fromIntegral startByte)
<$> Git.readTextFileOfBS commitHash file
replaceText :: Issue -> T.Text -> IO ()
diff --git a/app/TreeSitter.hs b/app/TreeSitter.hs
index e911d1b..230fefc 100644
--- a/app/TreeSitter.hs
+++ b/app/TreeSitter.hs
@@ -1,65 +1,60 @@
-module TreeSitter where
-
--- | References: [tree-sitter/api.h](https://github.com/tree-sitter/tree-sitter/blob/master/lib/include/tree_sitter/api.h)
+module TreeSitter
+ ( Language,
+ Node (..),
+ Point (..),
+ extract_comments,
+ tree_sitter_haskell,
+ )
+where
import Foreign.C.String (CString)
import Foreign.C.Types (CInt (..))
import Foreign.Ptr (Ptr)
-import Foreign.Storable (Storable (..), peek)
-
-data Parser
+import Foreign.Storable (Storable (..))
data Language
-data Tree = Tree
-
data Node = Node
+ { startPoint :: Point,
+ endPoint :: Point,
+ startByte :: CInt,
+ endByte :: CInt
+ }
+ deriving (Show, Eq)
instance Storable Node where
- sizeOf _ = 32
+ alignment _ = 8
+ sizeOf _ = 24
+ peek ptr =
+ Node
+ <$> peekByteOff ptr 0
+ <*> peekByteOff ptr 8
+ <*> peekByteOff ptr 16
+ <*> peekByteOff ptr 20
+ poke ptr (Node {..}) = do
+ pokeByteOff ptr 0 startPoint
+ pokeByteOff ptr 8 endPoint
+ pokeByteOff ptr 16 startByte
+ pokeByteOff ptr 20 endByte
data Point = Point
{ row :: CInt,
column :: CInt
- } deriving (Show)
+ }
+ deriving (Show, Eq)
instance Storable Point where
+ alignment _ = 4
sizeOf _ = 8
- alignment _ = 8
- peek p = Point <$> peekByteOff p 0 <*> peekByteOff p 4
-
-withParser :: Ptr Language -> (Ptr Parser -> IO a) -> IO a
-withParser l f = do
- p <- ts_parser_new
- ts_parser_set_language p l
- x <- f p
- ts_parser_delete p
- pure x
-
-foreign import ccall unsafe "ts_node_start_point_p" ts_node_start_point :: Ptr Node -> Ptr Point -> IO ()
-
-foreign import ccall unsafe "ts_node_end_point_p" ts_node_end_point :: Ptr Node -> Ptr Point -> IO ()
-
-foreign import ccall unsafe "ts_node_start_byte_p" ts_node_start_byte :: Ptr Node -> IO CInt
-
-foreign import ccall unsafe "ts_node_end_byte_p" ts_node_end_byte :: Ptr Node -> IO CInt
-
-foreign import ccall unsafe "ts_node_type_p" ts_node_type :: Ptr Node -> IO CString
-
-foreign import ccall unsafe "ts_node_named_child_p" ts_node_named_child :: Ptr Node -> CInt -> Ptr Node -> IO ()
-
-foreign import ccall unsafe "ts_node_named_child_count_p" ts_node_named_child_count :: Ptr Node -> IO CInt
-
-foreign import ccall unsafe "ts_tree_root_node_p" ts_tree_root_node :: Ptr Tree -> Ptr Node -> IO ()
-
-foreign import ccall unsafe "ts_tree_delete" ts_tree_delete :: Ptr Tree -> IO ()
-
-foreign import ccall unsafe "ts_parser_parse_string" ts_parser_parse_string :: Ptr Parser -> Ptr Tree -> CString -> CInt -> IO (Ptr Tree)
-
-foreign import ccall unsafe "ts_parser_new" ts_parser_new :: IO (Ptr Parser)
-
-foreign import ccall unsafe "ts_parser_delete" ts_parser_delete :: Ptr Parser -> IO ()
-
-foreign import ccall unsafe "ts_parser_set_language" ts_parser_set_language :: Ptr Parser -> Ptr Language -> IO ()
+ peek ptr = Point <$> peekByteOff ptr 0 <*> peekByteOff ptr 4
+ poke ptr (Point {..}) = pokeByteOff ptr 0 row >> pokeByteOff ptr 4 column
+
+foreign import ccall unsafe "extract_comments"
+ extract_comments ::
+ Ptr Language ->
+ CString ->
+ Ptr (Ptr Node) ->
+ Ptr CInt ->
+ IO ()
foreign import ccall unsafe "tree_sitter_haskell" tree_sitter_haskell :: Ptr Language
diff --git a/app/TreeSitter/bridge.c b/app/TreeSitter/bridge.c
deleted file mode 100644
index 904c88e..0000000
--- a/app/TreeSitter/bridge.c
+++ /dev/null
@@ -1,34 +0,0 @@
-#include "tree_sitter/api.h"
-#include "string.h"
-
-void ts_tree_root_node_p(TSTree *tree, TSNode *node) {
- (*node) = ts_tree_root_node(tree);
-}
-
-uint32_t ts_node_named_child_count_p(TSNode *node) {
- return ts_node_named_child_count(*node);
-}
-
-uint32_t ts_node_start_byte_p(TSNode *node) {
- return ts_node_start_byte(*node);
-}
-
-uint32_t ts_node_end_byte_p(TSNode *node) {
- return ts_node_end_byte(*node);
-}
-
-uint32_t ts_node_start_point_p(TSNode *node, TSPoint *point) {
- (*point) = ts_node_start_point(*node);
-}
-
-uint32_t ts_node_end_point_p(TSNode *node, TSPoint *point) {
- (*point) = ts_node_end_point(*node);
-}
-
-const char* ts_node_type_p(TSNode *node) {
- return ts_node_type(*node);
-}
-
-void ts_node_named_child_p(TSNode* self, uint32_t child_index, TSNode* node) {
- (*node) = ts_node_named_child(*self, child_index);
-}
diff --git a/app/tree_sitter.c b/app/tree_sitter.c
new file mode 100644
index 0000000..d0f9fa8
--- /dev/null
+++ b/app/tree_sitter.c
@@ -0,0 +1,85 @@
+#include "string.h"
+#include "tree_sitter/api.h"
+
+typedef struct Node {
+ TSPoint start_point;
+ TSPoint end_point;
+ uint32_t start_byte;
+ uint32_t end_byte;
+} Node;
+
+void extract_comments(
+ TSLanguage* language,
+ char* input,
+ Node** out,
+ uint32_t* out_len
+) {
+ TSParser* parser = ts_parser_new();
+ ts_parser_set_language(parser, language);
+ TSTree* tree = ts_parser_parse_string(parser, NULL, input, strlen(input));
+ TSNode root_node = ts_tree_root_node(tree);
+
+ char* pattern = "(comment) @comment";
+ uint32_t error_offset;
+ TSQueryError error_type;
+ TSQuery* query = ts_query_new(language, pattern, strlen(pattern), &error_offset, &error_type);
+ TSQueryCursor* query_cursor = ts_query_cursor_new();
+ ts_query_cursor_exec(query_cursor, query, root_node);
+
+ TSQueryMatch query_match;
+ uint32_t n_max = 1024;
+ *out = malloc(sizeof(Node) * n_max);
+ Node* node = *out;
+ uint32_t n = 0;
+ while (ts_query_cursor_next_match(query_cursor, &query_match)) {
+ if (n >= n_max) {
+ n_max *= 2;
+ *out = realloc(*out, sizeof(Node) * n_max);
+ node = *out + n;
+ }
+ TSNode captured_node = query_match.captures[0].node;
+ node->start_byte = ts_node_start_byte(captured_node);
+ node->end_byte = ts_node_end_byte(captured_node);
+ node->start_point = ts_node_start_point(captured_node);
+ node->end_point = ts_node_end_point(captured_node);
+ node++; n++;
+ }
+ *out_len = n;
+
+ ts_query_cursor_delete(query_cursor);
+ ts_query_delete(query);
+ ts_tree_delete(tree);
+ ts_parser_delete(parser);
+}
+
+void ts_tree_root_node_p(TSTree *tree, TSNode *node) {
+ (*node) = ts_tree_root_node(tree);
+}
+
+uint32_t ts_node_named_child_count_p(TSNode *node) {
+ return ts_node_named_child_count(*node);
+}
+
+uint32_t ts_node_start_byte_p(TSNode *node) {
+ return ts_node_start_byte(*node);
+}
+
+uint32_t ts_node_end_byte_p(TSNode *node) {
+ return ts_node_end_byte(*node);
+}
+
+uint32_t ts_node_start_point_p(TSNode *node, TSPoint *point) {
+ (*point) = ts_node_start_point(*node);
+}
+
+uint32_t ts_node_end_point_p(TSNode *node, TSPoint *point) {
+ (*point) = ts_node_end_point(*node);
+}
+
+const char* ts_node_type_p(TSNode *node) {
+ return ts_node_type(*node);
+}
+
+void ts_node_named_child_p(TSNode* self, uint32_t child_index, TSNode* node) {
+ (*node) = ts_node_named_child(*self, child_index);
+}