-
-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #90 from grin-compiler/32-trf-case-hoisting
Extended Syntax: case hoisting
- Loading branch information
Showing
3 changed files
with
328 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
123 changes: 123 additions & 0 deletions
123
grin/src/Transformations/ExtendedSyntax/Optimising/CaseHoisting.hs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
{-# LANGUAGE LambdaCase, TupleSections #-} | ||
module Transformations.ExtendedSyntax.Optimising.CaseHoisting where | ||
|
||
import Control.Monad | ||
import Control.Comonad | ||
import Control.Comonad.Cofree | ||
import Data.Functor.Foldable as Foldable | ||
import qualified Data.Foldable | ||
|
||
import Data.Map (Map) | ||
import qualified Data.Map as Map | ||
import Data.Set (Set) | ||
import qualified Data.Set as Set | ||
import qualified Data.Vector as Vector | ||
import Data.Bifunctor (first) | ||
|
||
import Grin.ExtendedSyntax.Grin | ||
import Grin.ExtendedSyntax.TypeEnv | ||
import Transformations.ExtendedSyntax.Util | ||
import Transformations.ExtendedSyntax.Names | ||
|
||
{- | ||
IDEA: | ||
If Alt had name then the HPT could calculate it's return type and store in TypeEnv | ||
-} | ||
|
||
getReturnTagSet :: TypeEnv -> Exp -> Maybe (Set Tag) | ||
getReturnTagSet typeEnv = cata folder where | ||
folder exp = case exp of | ||
EBindF _ _ ts -> ts | ||
SBlockF ts -> ts | ||
AltF _ _ ts -> ts | ||
ECaseF _ alts -> mconcat <$> sequence alts | ||
|
||
SReturnF val | ||
| Just (T_NodeSet ns) <- mTypeOfValTE typeEnv val | ||
-> Just (Map.keysSet ns) | ||
|
||
SAppF name _ | ||
| T_NodeSet ns <- fst $ functionType typeEnv name | ||
-> Just (Map.keysSet ns) | ||
|
||
SFetchF name | ||
| T_SimpleType (T_Location locs) <- variableType typeEnv name | ||
-> Just (mconcat [Map.keysSet (_location typeEnv Vector.! loc) | loc <- locs]) | ||
|
||
_ -> Nothing | ||
|
||
|
||
caseHoisting :: TypeEnv -> Exp -> (Exp, ExpChanges) | ||
caseHoisting typeEnv exp = first fst $ evalNameM exp $ histoM folder exp where | ||
|
||
folder :: ExpF (Cofree ExpF (Exp, Set Name)) -> NameM (Exp, Set Name) | ||
folder exp = case exp of | ||
-- middle case | ||
EBindF ((ECase val alts1, leftUse) :< _) (VarPat lpatName) | ||
(_ :< (EBindF ((ECase varName alts2, caseUse) :< _) lpat ((rightExp, rightUse) :< _))) | ||
| lpatName == varName | ||
, Just alts1Types <- sequence $ map (getReturnTagSet typeEnv) alts1 | ||
, Just matchList <- disjointMatch (zip alts1Types alts1) alts2 | ||
, Set.notMember varName rightUse -- allow only linear variables ; that are not used later | ||
-> do | ||
hoistedAlts <- mapM (hoistAlts lpatName) matchList | ||
pure (EBind (ECase val hoistedAlts) lpat rightExp, Set.delete varName $ mconcat [leftUse, caseUse, rightUse]) | ||
|
||
-- last case | ||
EBindF ((ECase val alts1, leftUse) :< _) (VarPat lpatName) ((ECase varName alts2, rightUse) :< _) | ||
| lpatName == varName | ||
, Just alts1Types <- sequence $ map (getReturnTagSet typeEnv) alts1 | ||
, Just matchList <- disjointMatch (zip alts1Types alts1) alts2 | ||
-> do | ||
hoistedAlts <- mapM (hoistAlts lpatName) matchList | ||
pure (ECase val hoistedAlts, Set.delete varName $ mconcat [leftUse, rightUse]) | ||
|
||
_ -> let useSub = Data.Foldable.fold (snd . extract <$> exp) | ||
useExp = foldNameUseExpF Set.singleton exp | ||
in pure (embed (fst . extract <$> exp), mconcat [useSub, useExp]) | ||
|
||
hoistAlts :: Name -> (Alt, Alt) -> NameM Alt | ||
hoistAlts lpatName (Alt cpat1 altName1 alt1, Alt cpat2 altName2 alt2) = do | ||
freshLPatName <- deriveNewName lpatName | ||
let nameMap = Map.singleton lpatName freshLPatName | ||
(freshAlt2, _) <- refreshNames nameMap $ | ||
EBind (SReturn $ Var freshLPatName) (VarPat altName2) alt2 | ||
pure . Alt cpat1 altName1 $ EBind (SBlock alt1) (VarPat freshLPatName) freshAlt2 | ||
|
||
disjointMatch :: [(Set Tag, Alt)] -> [Alt] -> Maybe [(Alt, Alt)] | ||
disjointMatch tsAlts1 alts2 | ||
| Just (defaults, tagMap) <- mconcat <$> mapM groupByCPats alts2 | ||
, length defaults <= 1 | ||
, Just (altPairs, _, _) <- Data.Foldable.foldrM (matchAlt tagMap) ([], defaults, Set.empty) tsAlts1 | ||
= Just altPairs | ||
disjointMatch _ _ = Nothing | ||
|
||
groupByCPats :: Alt -> Maybe ([Alt], Map Tag Alt) | ||
groupByCPats alt@(Alt cpat _ _) = case cpat of | ||
DefaultPat -> Just ([alt], mempty) | ||
NodePat tag _ -> Just ([], Map.singleton tag alt) | ||
_ -> Nothing | ||
|
||
matchAlt :: Map Tag Alt -> (Set Tag, Alt) -> ([(Alt, Alt)], [Alt], Set Tag) -> Maybe ([(Alt, Alt)], [Alt], Set Tag) | ||
matchAlt tagMap (ts, alt1) (matchList, defaults, coveredTags) | ||
-- regular node pattern | ||
| Set.size ts == 1 | ||
, tag <- Set.findMin ts | ||
, Set.notMember tag coveredTags | ||
, Just alt2 <- Map.lookup tag tagMap | ||
= Just ((alt1, alt2):matchList, defaults, Set.insert tag coveredTags) | ||
|
||
-- default can handle this | ||
| defaultAlt:[] <- defaults | ||
, Data.Foldable.all (flip Set.notMember coveredTags) ts | ||
= Just ((alt1, defaultAlt):matchList, [], coveredTags `mappend` ts) | ||
|
||
| otherwise = Nothing | ||
|
||
{- | ||
TODO: | ||
- add cloned variables to TypeEnv | ||
done - ignore non linear scrutinee | ||
IDEA: | ||
this could be supported if product type was available in GRIN then the second case could return from the hoisted case with a pair of the original two case results | ||
-} |
203 changes: 203 additions & 0 deletions
203
grin/test/Transformations/ExtendedSyntax/Optimising/CaseHoistingSpec.hs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
{-# LANGUAGE OverloadedStrings, QuasiQuotes, ViewPatterns #-} | ||
module Transformations.ExtendedSyntax.Optimising.CaseHoistingSpec where | ||
|
||
import Transformations.ExtendedSyntax.Optimising.CaseHoisting | ||
|
||
import Test.Hspec | ||
|
||
import Grin.ExtendedSyntax.TH | ||
import Grin.ExtendedSyntax.TypeCheck | ||
import Test.ExtendedSyntax.Assertions | ||
import Transformations.ExtendedSyntax.Names (ExpChanges(..)) | ||
|
||
|
||
runTests :: IO () | ||
runTests = hspec spec | ||
|
||
spec :: Spec | ||
spec = do | ||
it "last case" $ do | ||
let before = [prog| | ||
grinMain = | ||
v <- pure (CNil) | ||
u <- case v of | ||
(CNil) @ alt1 -> pure (CNil) | ||
(CCons a1 b1) @ alt2 -> pure (CCons a1 b1) | ||
case u of | ||
(CNil) @ alt3 -> pure alt3 | ||
(CCons a2 b2) @ alt4 -> pure (CNil) | ||
|] | ||
let after = [prog| | ||
grinMain = | ||
v <- pure (CNil) | ||
case v of | ||
(CNil) @ alt1 -> | ||
u.0 <- do | ||
pure (CNil) | ||
alt3.0 <- pure u.0 | ||
pure alt3.0 | ||
(CCons a1 b1) @ alt2 -> | ||
u.1 <- do | ||
pure (CCons a1 b1) | ||
alt4.0 <- pure u.1 | ||
pure (CNil) | ||
|] | ||
caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames) | ||
|
||
it "middle case" $ do | ||
let before = [prog| | ||
grinMain = | ||
v <- pure (CNil) | ||
u <- case v of | ||
(CNil) @ alt1 -> pure (CNil) | ||
(CCons a1 b1) @ alt2 -> pure (CCons a1 b1) | ||
r <- case u of | ||
(CNil) @ alt3 -> pure 1 | ||
(CCons a2 b2) @ alt4 -> pure 2 | ||
pure r | ||
|] | ||
let after = [prog| | ||
grinMain = | ||
v <- pure (CNil) | ||
r <- case v of | ||
(CNil) @ alt1 -> | ||
u.0 <- do | ||
pure (CNil) | ||
alt3.0 <- pure u.0 | ||
pure 1 | ||
(CCons a1 b1) @ alt2 -> | ||
u.1 <- do | ||
pure (CCons a1 b1) | ||
alt4.0 <- pure u.1 | ||
pure 2 | ||
pure r | ||
|] | ||
caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames) | ||
|
||
it "default pattern" $ do | ||
let before = [prog| | ||
grinMain = | ||
v <- pure (CNil) | ||
u <- case v of | ||
(CNil) @ alt1 -> pure (CNil) | ||
(CCons a1 b1) @ alt2 -> pure (CCons a1 b1) | ||
r <- case u of | ||
(CNil) @ alt3 -> pure (CNil) | ||
#default @ alt4 -> pure alt4 | ||
pure r | ||
|] | ||
let after = [prog| | ||
grinMain = | ||
v <- pure (CNil) | ||
r <- case v of | ||
(CNil) @ alt1 -> | ||
u.0 <- do | ||
pure (CNil) | ||
alt3.0 <- pure u.0 | ||
pure (CNil) | ||
(CCons a1 b1) @ alt2 -> | ||
u.1 <- do | ||
pure (CCons a1 b1) | ||
alt4.0 <- pure u.1 | ||
pure alt4.0 | ||
pure r | ||
|] | ||
caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames) | ||
|
||
it "case chain + no code duplication" $ do | ||
let before = [prog| | ||
grinMain = | ||
v <- pure 1 | ||
u <- case v of | ||
0 @ alt1 -> pure (CNil) | ||
1 @ alt2 -> pure (CCons v v) | ||
r <- case u of | ||
(CNil) @ alt3 -> pure (CEmpty) | ||
#default @ alt4 -> pure u | ||
q <- case r of | ||
(CVoid) @ alt5 -> | ||
pure (CEmpty) | ||
#default @ alt6 -> | ||
k0 <- pure 777 | ||
_1 <- _prim_int_print k0 | ||
pure r | ||
pure q | ||
|] | ||
let after = [prog| | ||
grinMain = | ||
v <- pure 1 | ||
r <- case v of | ||
0 @ alt1 -> | ||
u.0 <- do | ||
pure (CNil) | ||
alt3.0 <- pure u.0 | ||
pure (CEmpty) | ||
1 @ alt2 -> | ||
u.1 <- do | ||
pure (CCons v v) | ||
alt4.0 <- pure u.1 | ||
pure u.1 | ||
q <- case r of | ||
(CVoid) @ alt5 -> | ||
pure (CEmpty) | ||
#default @ alt6 -> | ||
k0 <- pure 777 | ||
_1 <- _prim_int_print k0 | ||
pure r | ||
pure q | ||
|] | ||
caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames) | ||
|
||
it "default chain" $ do | ||
let before = [prog| | ||
grinMain = | ||
v <- pure 1 | ||
u <- case v of | ||
0 @ alt1 -> pure (CNil) | ||
1 @ alt2 -> pure (CCons v v) | ||
r <- case u of | ||
#default @ alt3 -> pure u | ||
q <- case r of | ||
#default @ alt4 -> pure r | ||
pure q | ||
|] | ||
let after = [prog| | ||
grinMain = | ||
v <- pure 1 | ||
u <- case v of | ||
0 @ alt1 -> | ||
pure (CNil) | ||
1 @ alt2 -> | ||
pure (CCons v v) | ||
q <- case u of | ||
#default @ alt3 -> | ||
r.0 <- do | ||
pure u | ||
alt4.0 <- pure r.0 | ||
pure r.0 | ||
pure q | ||
|] | ||
caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames) | ||
|
||
it "ignore non linear variable" $ do | ||
let before = [prog| | ||
grinMain = | ||
v <- pure (CNil) | ||
u <- case v of | ||
#default @ alt1 -> pure v | ||
r <- case u of | ||
#default @ alt2 -> pure u | ||
x <- pure u | ||
pure r | ||
|] | ||
let after = [prog| | ||
grinMain = | ||
v <- pure (CNil) | ||
u <- case v of | ||
#default @ alt1 -> pure v | ||
r <- case u of | ||
#default @ alt2 -> pure u | ||
x <- pure u | ||
pure r | ||
|] | ||
caseHoisting (inferTypeEnv before) before `sameAs` (after, NoChange) |