-- Source: app/Comment.hs (blob 7bd4ad55112e4a46ecaa9b8bf17582f17d193361)
module Comment
  ( Comment (..),
    Point (..),
    getComments,
    extractComments,
    CommentStyle (..),
    uncomment,
    comment,
  )
where

import Comment.Language
import Control.Applicative (liftA2)
import Control.Exception (bracket, catch)
import Control.Monad
import Data.Binary (Binary)
import Data.ByteString qualified as B
import Data.ByteString.Lazy qualified as LB
import Data.List (find, sortBy)
import Data.List.NonEmpty qualified as N
import Data.Maybe (fromMaybe)
import Data.Ord (comparing)
import Data.Text qualified as T
import Data.Text.Encoding qualified as T
import Exception qualified as E
import Foreign.C.String
import Foreign.Marshal.Alloc (free, malloc)
import Foreign.Marshal.Array (allocaArray, mallocArray, peekArray)
import Foreign.Marshal.Utils (with)
import Foreign.Ptr (nullPtr)
import Foreign.Storable
import GHC.Generics (Generic)
import GHC.Int (Int64)
import Git qualified
import System.FilePath (takeExtension)
import TreeSitter.Node qualified as S
import TreeSitter.Parser qualified as S
import TreeSitter.Tree qualified as S

-- | A source-code comment together with the location it was found at.
--
-- The field descriptions below describe how 'extractComments' and
-- 'getComments' in this module fill the record.
data Comment = Comment
  { -- | The comment's text: the UTF-8 decoded bytes between 'startByte' and
    -- 'endByte'. For comments merged by 'getComments' it is the member texts
    -- joined with 'T.unlines'.
    text :: T.Text,
    -- | The language the containing file was parsed as.
    language :: Language,
    -- | Byte offset of the comment's first byte within the file contents.
    startByte :: Int64,
    -- | Byte offset just past the comment's last byte.
    endByte :: Int64,
    -- | Row and column at which the comment starts.
    startPoint :: Point,
    -- | Row and column at which the comment ends.
    endPoint :: Point,
    -- | Path of the file the comment was extracted from.
    filePath :: FilePath
  }
  deriving (Eq, Show)

-- | A position in a file, given as a row and a column.
--
-- 'extractComments' builds these by adding one to both the row and the column
-- reported by tree-sitter.
data Point = Point
  { -- | Row (line) of the position.
    row :: Int,
    -- | Column of the position.
    column :: Int
  }
  deriving (Eq, Show, Generic, Binary)

-- | Read @filePath@ as of @commitHash@ and return the comments found in it.
--
-- A file that cannot be read ('E.CannotReadFile') is treated as empty. The
-- language is derived from the file extension. The extracted comments are
-- sorted by byte range, and every run of comments in which each one starts on
-- the row directly after the previous one ends is merged into one 'Comment'.
getComments :: Git.CommitHash -> FilePath -> IO [Comment]
getComments commitHash filePath = do
  contents <-
    catch
      (Git.readTextFileOfBS commitHash filePath)
      (\(_ :: E.CannotReadFile) -> pure "")
  mergeLineComments <$> extractComments filePath language (LB.toStrict contents)
  where
    language = fromExtension (takeExtension filePath)

    -- Sort by (start, end) byte offsets, chain comments on consecutive rows,
    -- and collapse every chain into a single comment.
    mergeLineComments :: [Comment] -> [Comment]
    mergeLineComments comments =
      map mergeGroup (chainsBy onConsecutiveRows (sortBy (comparing byteRange) comments))

    byteRange :: Comment -> (Int64, Int64)
    byteRange c = (c.startByte, c.endByte)

    -- True when the second comment starts on the row right after the first ends.
    onConsecutiveRows :: Comment -> Comment -> Bool
    onConsecutiveRows previous next = previous.endPoint.row + 1 == next.startPoint.row

    -- The merged comment keeps the first member's remaining fields, spans from
    -- the first member's start to the last member's end, and joins all texts
    -- with `T.unlines`.
    mergeGroup :: N.NonEmpty Comment -> Comment
    mergeGroup chain@(firstComment N.:| _) =
      firstComment
        { text = T.unlines (map (.text) (N.toList chain)),
          startByte = firstComment.startByte,
          endByte = lastComment.endByte,
          startPoint = firstComment.startPoint,
          endPoint = lastComment.endPoint
        }
      where
        lastComment = N.last chain

    {- Like `Data.List.groupBy`, except that each new candidate is compared with the member most recently added to the current group. `Data.List.groupBy` compares every candidate with the group's first member instead. -}
    chainsBy :: (a -> a -> Bool) -> [a] -> [N.NonEmpty a]
    chainsBy _ [] = []
    chainsBy p (x : xs) = grow (N.singleton x) xs
      where
        -- The chain under construction is kept reversed, so its head is the
        -- most recently added member.
        grow chain [] = [N.reverse chain]
        grow chain@(newest N.:| _) (y : ys)
          | p newest y = grow (N.cons y chain) ys
          | otherwise = N.reverse chain : grow (N.singleton y) ys

-- | Parse @str'@ with the tree-sitter parser for @language@ and return one
-- 'Comment' for every node that 'commentsFromNodeRec' reports.
--
-- Each comment's 'text' is the slice of @str'@ between the node's start and
-- end byte. Rows and columns are tree-sitter's values plus one. @filePath@
-- and @language@ are copied into every result unchanged.
--
-- The syntax tree is released with @ts_tree_delete@ once the nodes have been
-- collected, even if collecting them throws. If the parser returns a null
-- tree, no comments are reported.
--
-- NOTE(review): 'T.decodeUtf8' throws when the slice is not valid UTF-8, and
-- because 'text' is lazy the exception surfaces only when the text is forced.
extractComments :: FilePath -> Language -> B.ByteString -> IO [Comment]
extractComments filePath language str' =
  S.withParser (parser language) $ \tsParser ->
    B.useAsCStringLen str' $ \(str, len) ->
      bracket
        (S.ts_parser_parse_string tsParser nullPtr str len)
        releaseTree
        collectComments
  where
    -- A null tree means that nothing was allocated, so there is nothing to free.
    releaseTree tree = unless (tree == nullPtr) (S.ts_tree_delete tree)

    collectComments tree
      | tree == nullPtr = pure []
      | otherwise =
          S.withRootNode tree $ \node ->
            map toComment <$> (commentsFromNodeRec language =<< peek node)

    -- Only plain fields of the peeked node and the Haskell-owned `str'` are
    -- read here, so the lazily built result stays valid after the tree is gone.
    toComment n' =
      let startByte = fromIntegral $ S.nodeStartByte n'
          endByte = fromIntegral $ S.nodeEndByte n'
          text =
            T.decodeUtf8
              . B.take (fromIntegral endByte - fromIntegral startByte)
              . B.drop (fromIntegral startByte)
              $ str'
          startPoint = fromTSPoint (S.nodeStartPoint n')
          endPoint = fromTSPoint (S.nodeEndPoint n')

          fromTSPoint (S.TSPoint {..}) = Point (fromIntegral pointRow + 1) (fromIntegral pointColumn + 1)
       in Comment {..}

-- | Collect the given node and all of its descendants, keeping only those
-- that 'isCommentNode' accepts for @language@.
commentsFromNodeRec :: Language -> S.Node -> IO [S.Node]
commentsFromNodeRec language root =
  childNodesFromNodeRec root >>= filterM (isCommentNode language)

-- | Whether the node's type name is one of the @nodeTypes@ of @language@.
isCommentNode :: Language -> S.Node -> IO Bool
isCommentNode language n = do
  typeName <- peekCString (S.nodeType n)
  pure (typeName `elem` nodeTypes language)

-- | The node itself followed by all of its descendants, depth-first: each
-- child is listed together with its own subtree before the next child.
childNodesFromNodeRec :: S.Node -> IO [S.Node]
childNodesFromNodeRec n = do
  children <- childNodesFromNode n
  subtrees <- traverse childNodesFromNodeRec children
  pure (n : concat subtrees)

-- | The direct children of a node, as copied out by
-- @ts_node_copy_child_nodes@.
--
-- Scratch memory is allocated with 'allocaArray' and 'with', so it is released
-- even if an exception interrupts the call. A node without children returns
-- @[]@ without allocating or calling into tree-sitter.
childNodesFromNode :: S.Node -> IO [S.Node]
childNodesFromNode n
  | numChildren == 0 = pure []
  | otherwise =
      allocaArray numChildren $ \children ->
        with (S.nodeTSNode n) $ \tsNode -> do
          S.ts_node_copy_child_nodes tsNode children
          peekArray numChildren children
  where
    numChildren :: Int
    numChildren = fromIntegral (S.nodeChildCount n)

-- | How a piece of text is marked up as a comment.
data CommentStyle
  = -- | Line comments. The field is the prefix that 'comment' puts, followed
    -- by a space, in front of every line.
    LineStyle T.Text
  | -- | A block comment. The fields are the opening and closing delimiters
    -- that 'comment' puts directly before and after the text.
    BlockStyle T.Text T.Text
  deriving (Eq, Show, Generic, Binary)

-- | Turn text into a comment of the given style.
--
-- 'LineStyle' prefixes every line with the line prefix and a single space and
-- terminates every line with a newline. 'BlockStyle' puts the delimiters
-- directly around the text without adding any whitespace.
comment :: CommentStyle -> T.Text -> T.Text
comment style body = case style of
  LineStyle linePrefix ->
    T.unlines [linePrefix <> " " <> line | line <- T.lines body]
  BlockStyle blockStart blockEnd ->
    blockStart <> body <> blockEnd

-- | Remove the comment markers from @rawText@ and report which style it used.
--
-- Every line is first stripped of surrounding whitespace. If the language
-- defines block comments and the stripped text starts with one of its block
-- openers, the text is treated as a block comment with that opener and the
-- language's block closer. Otherwise it is treated as line comments using the
-- language's line prefix.
uncomment :: Language -> T.Text -> (CommentStyle, T.Text)
uncomment language rawText =
  case matchingBlock of
    Just (blockStart, blockEnd) ->
      ( BlockStyle blockStart blockEnd,
        stripBlockComment blockStart blockEnd text
      )
    Nothing ->
      ( LineStyle linePrefix,
        stripLineComments linePrefix text
      )
  where
    linePrefix = lineStart language

    text = T.intercalate "\n" (map T.strip (T.lines rawText))

    -- The first block opener that the text starts with, paired with the closer.
    matchingBlock = do
      (blockStarts, blockEnd) <- block language
      blockStart <- find (`T.isPrefixOf` text) blockStarts
      pure (blockStart, blockEnd)

-- | Remove the line-comment prefix from every line that starts with it, along
-- with the whitespace following the prefix. Lines that do not start with the
-- prefix are kept unchanged. The lines are re-joined with single newlines, so
-- a trailing newline is not preserved.
stripLineComments :: T.Text -> T.Text -> T.Text
stripLineComments lineStart text =
  T.intercalate "\n" (map stripLine (T.lines text))
  where
    stripLine line = maybe line T.stripStart (T.stripPrefix lineStart line)

-- | Remove the block-comment delimiters from a comment and strip the
-- surrounding whitespace from what remains.
--
-- The opener is removed if the text starts with it, and the closer is removed
-- if the remaining text ends with it. The two steps are independent: a
-- missing closer does not undo the removal of the opener.
stripBlockComment :: T.Text -> T.Text -> T.Text -> T.Text
stripBlockComment blockStart blockEnd =
  T.strip . dropSuffix blockEnd . dropPrefix blockStart
  where
    -- Each step falls back to its own input, not to the original text.
    dropPrefix prefix t = fromMaybe t (T.stripPrefix prefix t)
    dropSuffix suffix t = fromMaybe t (T.stripSuffix suffix t)