forked from purescript/purescript
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTCO.hs
More file actions
191 lines (167 loc) · 8.74 KB
/
TCO.hs
File metadata and controls
191 lines (167 loc) · 8.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
-- | This module implements tail call elimination.
module Language.PureScript.CoreImp.Optimizer.TCO (tco) where
import Prelude
import Control.Applicative (empty)
import Control.Monad (guard)
import Control.Monad.State (State, evalState, gets, modify)
import Data.Functor (($>))
import Data.Set qualified as S
import Data.Text (Text, pack)
import Language.PureScript.CoreImp.AST (AST(..), InitializerEffects(..), UnaryOperator(..), everything, everywhereTopDownM)
import Language.PureScript.AST.SourcePos (SourceSpan)
import Safe (headDef, tailSafe)
-- | Eliminate tail calls.
--
-- Every 'VariableIntroduction' whose initializer is a (possibly curried)
-- 'Function' is inspected. If the introduced name is only ever used in
-- saturated tail calls, the function body is rewritten into a loop. Uses
-- through local helper functions count too, as long as those helpers are
-- themselves only used in saturated tail calls.
--
-- The rewritten body drives a local @$tco_loop@ function from a 'While'
-- loop. Self tail calls become reassignments of loop variables instead of
-- nested calls.
--
-- The 'State' counter makes the "done" flag name unique per converted
-- function (@$tco_done@, @$tco_done1@, ...).
tco :: AST -> AST
tco = flip evalState 0 . everywhereTopDownM convert where
  -- Loop variable holding the next value of an outer (non-innermost) argument.
  tcoVar :: Text -> Text
  tcoVar arg = "$tco_var_" <> arg

  -- The converted function's own parameters are renamed with this prefix.
  -- The original argument names become the parameters of $tco_loop.
  copyVar :: Text -> Text
  copyVar arg = "$copy_" <> arg

  -- Name of the "loop finished" flag. The counter keeps it unique for each
  -- converted function: "$tco_done", "$tco_done1", "$tco_done2", ...
  tcoDoneM :: State Int Text
  tcoDoneM = gets $ \count -> "$tco_done" <>
    if count == 0 then "" else pack . show $ count

  tcoLoop :: Text
  tcoLoop = "$tco_loop"

  tcoResult :: Text
  tcoResult = "$tco_result"

  -- Rewrite a single node. Only a variable introduction bound to a
  -- tail-recursive function is converted; everything else is kept as-is.
  convert :: AST -> State Int AST
  convert (VariableIntroduction ss name (Just (p, fn@Function {})))
      | Just trFns <- findTailRecursiveFns name arity body'
      = VariableIntroduction ss name . Just . (p,) . replace <$> toLoop trFns name arity outerArgs innerArgs body'
    where
      -- Arguments of the innermost function (the one whose body is body').
      innerArgs = headDef [] argss
      -- Arguments of all enclosing curried functions, outermost first.
      outerArgs = concat . reverse $ tailSafe argss
      -- This is the number of calls, not the number of arguments, if there's
      -- ever a practical difference.
      arity = length argss
      (argss, body', replace) = topCollectAllFunctionArgs [] id fn
  convert js = pure js

  -- Peel off nested curried functions: a 'Function' whose body returns
  -- another 'Function', and so on.
  --
  -- Returns three things:
  --   * the argument lists collected, innermost first;
  --   * the innermost body;
  --   * a function that rebuilds the same nesting around a new body.
  -- Every argument list in the rebuilt functions is passed through argMapper.
  rewriteFunctionsWith :: ([Text] -> [Text]) -> [[Text]] -> (AST -> AST) -> AST -> ([[Text]], AST, AST -> AST)
  rewriteFunctionsWith argMapper = collectAllFunctionArgs
    where
      -- A block that starts with a return. Statements after that return are
      -- unreachable and are not kept in the rebuilt block.
      collectAllFunctionArgs allArgs f (Function s1 ident args (Block s2 (body@(Return _ _):_))) =
        collectAllFunctionArgs (args : allArgs) (\b -> f (Function s1 ident (argMapper args) (Block s2 [b]))) body
      collectAllFunctionArgs allArgs f (Function ss ident args body@(Block _ _)) =
        (args : allArgs, body, f . Function ss ident (argMapper args))
      collectAllFunctionArgs allArgs f (Return s1 (Function s2 ident args (Block s3 [body]))) =
        collectAllFunctionArgs (args : allArgs) (\b -> f (Return s1 (Function s2 ident (argMapper args) (Block s3 [b])))) body
      collectAllFunctionArgs allArgs f (Return s1 (Function s2 ident args body@(Block _ _))) =
        (args : allArgs, body, f . Return s1 . Function s2 ident (argMapper args))
      collectAllFunctionArgs allArgs f body = (allArgs, body, f)

  -- Used on the function being converted: its parameters get the $copy_ prefix.
  topCollectAllFunctionArgs :: [[Text]] -> (AST -> AST) -> AST -> ([[Text]], AST, AST -> AST)
  topCollectAllFunctionArgs = rewriteFunctionsWith (map copyVar)

  -- Used on local helper functions: parameter names are left unchanged.
  innerCollectAllFunctionArgs :: [[Text]] -> (AST -> AST) -> AST -> ([[Text]], AST, AST -> AST)
  innerCollectAllFunctionArgs = rewriteFunctionsWith id

  -- Number of `Var ident` nodes anywhere inside the given AST.
  countReferences :: Text -> AST -> Int
  countReferences ident = everything (+) match where
    match :: AST -> Int
    match (Var _ ident') | ident == ident' = 1
    match _ = 0

  -- If `ident` is a tail-recursive function, returns a set of identifiers
  -- that are locally bound to functions participating in the tail recursion.
  -- Otherwise, returns Nothing.
  findTailRecursiveFns :: Text -> Int -> AST -> Maybe (S.Set Text)
  findTailRecursiveFns ident arity js = guard (countReferences ident js > 0) *> go (S.empty, S.singleton (ident, arity))
    where
      -- Worklist loop over pending (name, arity) pairs:
      --   * take the smallest pending pair and check it with
      --     findTailPositionDeps;
      --   * record its name in `known`;
      --   * schedule any newly reported helpers whose names are not known yet.
      -- Any failed check makes the whole result Nothing.
      go :: (S.Set Text, S.Set (Text, Int)) -> Maybe (S.Set Text)
      go (known, required) =
        case S.minView required of
          Just (r, required') -> do
            required'' <- findTailPositionDeps r js
            go (S.insert (fst r) known, required' <> S.filter (not . (`S.member` known) . fst) required'')
          Nothing ->
            pure known

  -- Returns set of identifiers (with their arities) that need to be used
  -- exclusively in tail calls using their full arity in order for this
  -- identifier to be considered in tail position (or Nothing if this
  -- identifier is used somewhere not as a tail call with full arity).
  findTailPositionDeps :: (Text, Int) -> AST -> Maybe (S.Set (Text, Int))
  findTailPositionDeps (ident, arity) = allInTailPosition where
    countSelfReferences = countReferences ident

    -- A returned expression may mention ident only as one saturated self call.
    allInTailPosition (Return _ expr)
      | isSelfCall ident arity expr = guard (countSelfReferences expr == 1) $> S.empty
      | otherwise = guard (countSelfReferences expr == 0) $> S.empty
    -- Loop headers and conditions must not mention ident; only bodies may.
    allInTailPosition (While _ js1 body)
      = guard (countSelfReferences js1 == 0) *> allInTailPosition body
    allInTailPosition (For _ _ js1 js2 body)
      = guard (countSelfReferences js1 == 0 && countSelfReferences js2 == 0) *> allInTailPosition body
    allInTailPosition (ForIn _ _ js1 body)
      = guard (countSelfReferences js1 == 0) *> allInTailPosition body
    allInTailPosition (IfElse _ js1 body el)
      = guard (countSelfReferences js1 == 0) *> liftA2 mappend (allInTailPosition body) (foldMapA allInTailPosition el)
    allInTailPosition (Block _ body)
      = foldMapA allInTailPosition body
    allInTailPosition (Throw _ js1)
      = guard (countSelfReferences js1 == 0) $> S.empty
    allInTailPosition (ReturnNoResult _)
      = pure S.empty
    allInTailPosition (VariableIntroduction _ _ Nothing)
      = pure S.empty
    -- A local anonymous function may mention ident, but only in tail position
    -- inside its own body. The local function (with its arity) is then
    -- reported as a dependency, which must itself be checked.
    allInTailPosition (VariableIntroduction _ ident' (Just (_, js1)))
      | countSelfReferences js1 == 0 = pure S.empty
      | Function _ Nothing _ _ <- js1
      , (argss, body, _) <- innerCollectAllFunctionArgs [] id js1
      = S.insert (ident', length argss) <$> allInTailPosition body
      | otherwise = empty
    allInTailPosition (Assignment _ _ js1)
      = guard (countSelfReferences js1 == 0) $> S.empty
    allInTailPosition (Comment _ js1)
      = allInTailPosition js1
    allInTailPosition _
      = empty

  -- Build the loop form of a tail-recursive function body. The result is a
  -- block containing, in order:
  --   * $tco_var_<arg> variables, seeded from the $copy_<arg> parameters, for
  --     each outer argument;
  --   * the done flag, initialised to false;
  --   * an uninitialised result variable;
  --   * a $tco_loop function whose body is the loopified original body;
  --   * a while loop that calls $tco_loop until the done flag is set;
  --   * a final return of the result variable.
  toLoop :: S.Set Text -> Text -> Int -> [Text] -> [Text] -> AST -> State Int AST
  toLoop trFns ident arity outerArgs innerArgs js = do
      tcoDone <- tcoDoneM
      modify (+ 1)

      let
        markDone :: Maybe SourceSpan -> AST
        markDone ss = Assignment ss (Var ss tcoDone) (BooleanLiteral ss True)

        -- Rewrite statements in tail position:
        --   * a saturated self call assigns the new argument values to the
        --     loop variables and returns nothing;
        --   * a return through another tail-recursive helper is kept as-is,
        --     since those helpers are loopified themselves;
        --   * every other return first sets the done flag.
        loopify :: AST -> AST
        loopify (Return ss ret)
          | isSelfCall ident arity ret =
            let
              allArgumentValues = concat $ collectArgs [] ret
            in
              Block ss $
                zipWith (\val arg ->
                  Assignment ss (Var ss (tcoVar arg)) val) allArgumentValues outerArgs
                ++ zipWith (\val arg ->
                  Assignment ss (Var ss (copyVar arg)) val) (drop (length outerArgs) allArgumentValues) innerArgs
                ++ [ ReturnNoResult ss ]
          | isIndirectSelfCall ret = Return ss ret
          | otherwise = Block ss [ markDone ss, Return ss ret ]
        loopify (ReturnNoResult ss) = Block ss [ markDone ss, ReturnNoResult ss ]
        loopify (While ss cond body) = While ss cond (loopify body)
        loopify (For ss i js1 js2 body) = For ss i js1 js2 (loopify body)
        loopify (ForIn ss i js1 body) = ForIn ss i js1 (loopify body)
        loopify (IfElse ss cond body el) = IfElse ss cond (loopify body) (fmap loopify el)
        loopify (Block ss body) = Block ss (map loopify body)
        -- findTailPositionDeps looks through Comment nodes, so loopify must
        -- too. Otherwise a commented return in tail position would never set
        -- the done flag. The while loop would then keep re-running $tco_loop
        -- with unchanged arguments.
        loopify (Comment com body) = Comment com (loopify body)
        loopify (VariableIntroduction ss f (Just (p, fn@(Function _ Nothing _ _))))
          | (_, body, replace) <- innerCollectAllFunctionArgs [] id fn
          , f `S.member` trFns = VariableIntroduction ss f (Just (p, replace (loopify body)))
        loopify other = other

      pure $ Block rootSS $
        map (\arg -> VariableIntroduction rootSS (tcoVar arg) (Just (UnknownEffects, Var rootSS (copyVar arg)))) outerArgs ++
        [ VariableIntroduction rootSS tcoDone (Just (UnknownEffects, BooleanLiteral rootSS False))
        , VariableIntroduction rootSS tcoResult Nothing
        , Function rootSS (Just tcoLoop) (outerArgs ++ innerArgs) (Block rootSS [loopify js])
        , While rootSS (Unary rootSS Not (Var rootSS tcoDone))
            (Block rootSS
              [Assignment rootSS (Var rootSS tcoResult) (App rootSS (Var rootSS tcoLoop) (map (Var rootSS . tcoVar) outerArgs ++ map (Var rootSS . copyVar) innerArgs))])
        , Return rootSS (Var rootSS tcoResult)
        ]
    where
      rootSS = Nothing

      -- Argument lists of a curried application, outermost function first:
      -- for f(a)(b) this yields [[a], [b]].
      collectArgs :: [[AST]] -> AST -> [[AST]]
      collectArgs acc (App _ fn args') = collectArgs (args' : acc) fn
      collectArgs acc _ = acc

      -- True when the head of an application chain is one of the
      -- tail-recursive functions found by findTailRecursiveFns.
      isIndirectSelfCall :: AST -> Bool
      isIndirectSelfCall (App _ (Var _ ident') _) = ident' `S.member` trFns
      isIndirectSelfCall (App _ fn _) = isIndirectSelfCall fn
      isIndirectSelfCall _ = False

  -- True when the expression is `ident` applied to exactly `arity`
  -- successive argument lists (a saturated curried call).
  isSelfCall :: Text -> Int -> AST -> Bool
  isSelfCall ident 1 (App _ (Var _ ident') _) = ident == ident'
  isSelfCall ident arity (App _ fn _) = isSelfCall ident (arity - 1) fn
  isSelfCall _ _ _ = False
-- | Applicative counterpart of 'foldMap'.
--
-- Each element is mapped to an effectful monoidal value. The effects run
-- left to right, and the results are combined with 'mappend', starting from
-- @'pure' 'mempty'@. For 'Maybe' or 'Either', the first failure wins.
foldMapA :: (Applicative f, Monoid w, Foldable t) => (a -> f w) -> t a -> f w
foldMapA f = foldr step (pure mempty)
  where
    step x acc = mappend <$> f x <*> acc