Skip to content

Commit

Permalink
Merge pull request #90 from grin-compiler/32-trf-case-hoisting
Browse files Browse the repository at this point in the history
Extended Syntax: case hoisting
  • Loading branch information
Anabra authored Apr 19, 2020
2 parents da448af + 9bfc107 commit 0c622bd
Show file tree
Hide file tree
Showing 3 changed files with 328 additions and 0 deletions.
2 changes: 2 additions & 0 deletions grin/grin.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ library
Transformations.ExtendedSyntax.MangleNames
Transformations.ExtendedSyntax.StaticSingleAssignment
Transformations.ExtendedSyntax.Optimising.ArityRaising
Transformations.ExtendedSyntax.Optimising.CaseHoisting
Transformations.ExtendedSyntax.Optimising.CopyPropagation
Transformations.ExtendedSyntax.Optimising.ConstantPropagation
Transformations.ExtendedSyntax.Optimising.CSE
Expand Down Expand Up @@ -308,6 +309,7 @@ test-suite grin-test
Transformations.ExtendedSyntax.MangleNamesSpec
Transformations.ExtendedSyntax.StaticSingleAssignmentSpec
Transformations.ExtendedSyntax.Optimising.ArityRaisingSpec
Transformations.ExtendedSyntax.Optimising.CaseHoistingSpec
Transformations.ExtendedSyntax.Optimising.CopyPropagationSpec
Transformations.ExtendedSyntax.Optimising.CSESpec
Transformations.ExtendedSyntax.Optimising.EvaluatedCaseEliminationSpec
Expand Down
123 changes: 123 additions & 0 deletions grin/src/Transformations/ExtendedSyntax/Optimising/CaseHoisting.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
{-# LANGUAGE LambdaCase, TupleSections #-}
module Transformations.ExtendedSyntax.Optimising.CaseHoisting where

import Control.Monad
import Control.Comonad
import Control.Comonad.Cofree
import Data.Functor.Foldable as Foldable
import qualified Data.Foldable

import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import qualified Data.Vector as Vector
import Data.Bifunctor (first)

import Grin.ExtendedSyntax.Grin
import Grin.ExtendedSyntax.TypeEnv
import Transformations.ExtendedSyntax.Util
import Transformations.ExtendedSyntax.Names

{-
IDEA:
If Alt had name then the HPT could calculate it's return type and store in TypeEnv
-}

getReturnTagSet :: TypeEnv -> Exp -> Maybe (Set Tag)
getReturnTagSet typeEnv = cata folder where
folder exp = case exp of
EBindF _ _ ts -> ts
SBlockF ts -> ts
AltF _ _ ts -> ts
ECaseF _ alts -> mconcat <$> sequence alts

SReturnF val
| Just (T_NodeSet ns) <- mTypeOfValTE typeEnv val
-> Just (Map.keysSet ns)

SAppF name _
| T_NodeSet ns <- fst $ functionType typeEnv name
-> Just (Map.keysSet ns)

SFetchF name
| T_SimpleType (T_Location locs) <- variableType typeEnv name
-> Just (mconcat [Map.keysSet (_location typeEnv Vector.! loc) | loc <- locs])

_ -> Nothing


caseHoisting :: TypeEnv -> Exp -> (Exp, ExpChanges)
caseHoisting typeEnv exp = first fst $ evalNameM exp $ histoM folder exp where

folder :: ExpF (Cofree ExpF (Exp, Set Name)) -> NameM (Exp, Set Name)
folder exp = case exp of
-- middle case
EBindF ((ECase val alts1, leftUse) :< _) (VarPat lpatName)
(_ :< (EBindF ((ECase varName alts2, caseUse) :< _) lpat ((rightExp, rightUse) :< _)))
| lpatName == varName
, Just alts1Types <- sequence $ map (getReturnTagSet typeEnv) alts1
, Just matchList <- disjointMatch (zip alts1Types alts1) alts2
, Set.notMember varName rightUse -- allow only linear variables ; that are not used later
-> do
hoistedAlts <- mapM (hoistAlts lpatName) matchList
pure (EBind (ECase val hoistedAlts) lpat rightExp, Set.delete varName $ mconcat [leftUse, caseUse, rightUse])

-- last case
EBindF ((ECase val alts1, leftUse) :< _) (VarPat lpatName) ((ECase varName alts2, rightUse) :< _)
| lpatName == varName
, Just alts1Types <- sequence $ map (getReturnTagSet typeEnv) alts1
, Just matchList <- disjointMatch (zip alts1Types alts1) alts2
-> do
hoistedAlts <- mapM (hoistAlts lpatName) matchList
pure (ECase val hoistedAlts, Set.delete varName $ mconcat [leftUse, rightUse])

_ -> let useSub = Data.Foldable.fold (snd . extract <$> exp)
useExp = foldNameUseExpF Set.singleton exp
in pure (embed (fst . extract <$> exp), mconcat [useSub, useExp])

hoistAlts :: Name -> (Alt, Alt) -> NameM Alt
hoistAlts lpatName (Alt cpat1 altName1 alt1, Alt cpat2 altName2 alt2) = do
freshLPatName <- deriveNewName lpatName
let nameMap = Map.singleton lpatName freshLPatName
(freshAlt2, _) <- refreshNames nameMap $
EBind (SReturn $ Var freshLPatName) (VarPat altName2) alt2
pure . Alt cpat1 altName1 $ EBind (SBlock alt1) (VarPat freshLPatName) freshAlt2

disjointMatch :: [(Set Tag, Alt)] -> [Alt] -> Maybe [(Alt, Alt)]
disjointMatch tsAlts1 alts2
| Just (defaults, tagMap) <- mconcat <$> mapM groupByCPats alts2
, length defaults <= 1
, Just (altPairs, _, _) <- Data.Foldable.foldrM (matchAlt tagMap) ([], defaults, Set.empty) tsAlts1
= Just altPairs
disjointMatch _ _ = Nothing

groupByCPats :: Alt -> Maybe ([Alt], Map Tag Alt)
groupByCPats alt@(Alt cpat _ _) = case cpat of
DefaultPat -> Just ([alt], mempty)
NodePat tag _ -> Just ([], Map.singleton tag alt)
_ -> Nothing

matchAlt :: Map Tag Alt -> (Set Tag, Alt) -> ([(Alt, Alt)], [Alt], Set Tag) -> Maybe ([(Alt, Alt)], [Alt], Set Tag)
matchAlt tagMap (ts, alt1) (matchList, defaults, coveredTags)
-- regular node pattern
| Set.size ts == 1
, tag <- Set.findMin ts
, Set.notMember tag coveredTags
, Just alt2 <- Map.lookup tag tagMap
= Just ((alt1, alt2):matchList, defaults, Set.insert tag coveredTags)

-- default can handle this
| defaultAlt:[] <- defaults
, Data.Foldable.all (flip Set.notMember coveredTags) ts
= Just ((alt1, defaultAlt):matchList, [], coveredTags `mappend` ts)

| otherwise = Nothing

{-
TODO:
- add cloned variables to TypeEnv
done - ignore non linear scrutinee
IDEA:
this could be supported if product type was available in GRIN then the second case could return from the hoisted case with a pair of the original two case results
-}
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
{-# LANGUAGE OverloadedStrings, QuasiQuotes, ViewPatterns #-}
module Transformations.ExtendedSyntax.Optimising.CaseHoistingSpec where

import Transformations.ExtendedSyntax.Optimising.CaseHoisting

import Test.Hspec

import Grin.ExtendedSyntax.TH
import Grin.ExtendedSyntax.TypeCheck
import Test.ExtendedSyntax.Assertions
import Transformations.ExtendedSyntax.Names (ExpChanges(..))


runTests :: IO ()
runTests = hspec spec

spec :: Spec
spec = do
it "last case" $ do
let before = [prog|
grinMain =
v <- pure (CNil)
u <- case v of
(CNil) @ alt1 -> pure (CNil)
(CCons a1 b1) @ alt2 -> pure (CCons a1 b1)
case u of
(CNil) @ alt3 -> pure alt3
(CCons a2 b2) @ alt4 -> pure (CNil)
|]
let after = [prog|
grinMain =
v <- pure (CNil)
case v of
(CNil) @ alt1 ->
u.0 <- do
pure (CNil)
alt3.0 <- pure u.0
pure alt3.0
(CCons a1 b1) @ alt2 ->
u.1 <- do
pure (CCons a1 b1)
alt4.0 <- pure u.1
pure (CNil)
|]
caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames)

it "middle case" $ do
let before = [prog|
grinMain =
v <- pure (CNil)
u <- case v of
(CNil) @ alt1 -> pure (CNil)
(CCons a1 b1) @ alt2 -> pure (CCons a1 b1)
r <- case u of
(CNil) @ alt3 -> pure 1
(CCons a2 b2) @ alt4 -> pure 2
pure r
|]
let after = [prog|
grinMain =
v <- pure (CNil)
r <- case v of
(CNil) @ alt1 ->
u.0 <- do
pure (CNil)
alt3.0 <- pure u.0
pure 1
(CCons a1 b1) @ alt2 ->
u.1 <- do
pure (CCons a1 b1)
alt4.0 <- pure u.1
pure 2
pure r
|]
caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames)

it "default pattern" $ do
let before = [prog|
grinMain =
v <- pure (CNil)
u <- case v of
(CNil) @ alt1 -> pure (CNil)
(CCons a1 b1) @ alt2 -> pure (CCons a1 b1)
r <- case u of
(CNil) @ alt3 -> pure (CNil)
#default @ alt4 -> pure alt4
pure r
|]
let after = [prog|
grinMain =
v <- pure (CNil)
r <- case v of
(CNil) @ alt1 ->
u.0 <- do
pure (CNil)
alt3.0 <- pure u.0
pure (CNil)
(CCons a1 b1) @ alt2 ->
u.1 <- do
pure (CCons a1 b1)
alt4.0 <- pure u.1
pure alt4.0
pure r
|]
caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames)

it "case chain + no code duplication" $ do
let before = [prog|
grinMain =
v <- pure 1
u <- case v of
0 @ alt1 -> pure (CNil)
1 @ alt2 -> pure (CCons v v)
r <- case u of
(CNil) @ alt3 -> pure (CEmpty)
#default @ alt4 -> pure u
q <- case r of
(CVoid) @ alt5 ->
pure (CEmpty)
#default @ alt6 ->
k0 <- pure 777
_1 <- _prim_int_print k0
pure r
pure q
|]
let after = [prog|
grinMain =
v <- pure 1
r <- case v of
0 @ alt1 ->
u.0 <- do
pure (CNil)
alt3.0 <- pure u.0
pure (CEmpty)
1 @ alt2 ->
u.1 <- do
pure (CCons v v)
alt4.0 <- pure u.1
pure u.1
q <- case r of
(CVoid) @ alt5 ->
pure (CEmpty)
#default @ alt6 ->
k0 <- pure 777
_1 <- _prim_int_print k0
pure r
pure q
|]
caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames)

it "default chain" $ do
let before = [prog|
grinMain =
v <- pure 1
u <- case v of
0 @ alt1 -> pure (CNil)
1 @ alt2 -> pure (CCons v v)
r <- case u of
#default @ alt3 -> pure u
q <- case r of
#default @ alt4 -> pure r
pure q
|]
let after = [prog|
grinMain =
v <- pure 1
u <- case v of
0 @ alt1 ->
pure (CNil)
1 @ alt2 ->
pure (CCons v v)
q <- case u of
#default @ alt3 ->
r.0 <- do
pure u
alt4.0 <- pure r.0
pure r.0
pure q
|]
caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames)

it "ignore non linear variable" $ do
let before = [prog|
grinMain =
v <- pure (CNil)
u <- case v of
#default @ alt1 -> pure v
r <- case u of
#default @ alt2 -> pure u
x <- pure u
pure r
|]
let after = [prog|
grinMain =
v <- pure (CNil)
u <- case v of
#default @ alt1 -> pure v
r <- case u of
#default @ alt2 -> pure u
x <- pure u
pure r
|]
caseHoisting (inferTypeEnv before) before `sameAs` (after, NoChange)

0 comments on commit 0c622bd

Please sign in to comment.