diff --git a/grin/grin.cabal b/grin/grin.cabal index 50e4c951..d4d37b7e 100644 --- a/grin/grin.cabal +++ b/grin/grin.cabal @@ -147,6 +147,7 @@ library Transformations.ExtendedSyntax.MangleNames Transformations.ExtendedSyntax.StaticSingleAssignment Transformations.ExtendedSyntax.Optimising.ArityRaising + Transformations.ExtendedSyntax.Optimising.CaseHoisting Transformations.ExtendedSyntax.Optimising.CopyPropagation Transformations.ExtendedSyntax.Optimising.ConstantPropagation Transformations.ExtendedSyntax.Optimising.CSE @@ -308,6 +309,7 @@ test-suite grin-test Transformations.ExtendedSyntax.MangleNamesSpec Transformations.ExtendedSyntax.StaticSingleAssignmentSpec Transformations.ExtendedSyntax.Optimising.ArityRaisingSpec + Transformations.ExtendedSyntax.Optimising.CaseHoistingSpec Transformations.ExtendedSyntax.Optimising.CopyPropagationSpec Transformations.ExtendedSyntax.Optimising.CSESpec Transformations.ExtendedSyntax.Optimising.EvaluatedCaseEliminationSpec diff --git a/grin/src/Transformations/ExtendedSyntax/Optimising/CaseHoisting.hs b/grin/src/Transformations/ExtendedSyntax/Optimising/CaseHoisting.hs new file mode 100644 index 00000000..f29f1090 --- /dev/null +++ b/grin/src/Transformations/ExtendedSyntax/Optimising/CaseHoisting.hs @@ -0,0 +1,123 @@ +{-# LANGUAGE LambdaCase, TupleSections #-} +module Transformations.ExtendedSyntax.Optimising.CaseHoisting where + +import Control.Monad +import Control.Comonad +import Control.Comonad.Cofree +import Data.Functor.Foldable as Foldable +import qualified Data.Foldable + +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Set (Set) +import qualified Data.Set as Set +import qualified Data.Vector as Vector +import Data.Bifunctor (first) + +import Grin.ExtendedSyntax.Grin +import Grin.ExtendedSyntax.TypeEnv +import Transformations.ExtendedSyntax.Util +import Transformations.ExtendedSyntax.Names + +{- + IDEA: + If Alt had name then the HPT could calculate it's return type and store in TypeEnv +-} + +getReturnTagSet :: TypeEnv -> Exp -> Maybe (Set Tag) +getReturnTagSet typeEnv = cata folder where + folder exp = case exp of + EBindF _ _ ts -> ts + SBlockF ts -> ts + AltF _ _ ts -> ts + ECaseF _ alts -> mconcat <$> sequence alts + + SReturnF val + | Just (T_NodeSet ns) <- mTypeOfValTE typeEnv val + -> Just (Map.keysSet ns) + + SAppF name _ + | T_NodeSet ns <- fst $ functionType typeEnv name + -> Just (Map.keysSet ns) + + SFetchF name + | T_SimpleType (T_Location locs) <- variableType typeEnv name + -> Just (mconcat [Map.keysSet (_location typeEnv Vector.! loc) | loc <- locs]) + + _ -> Nothing + + +caseHoisting :: TypeEnv -> Exp -> (Exp, ExpChanges) +caseHoisting typeEnv exp = first fst $ evalNameM exp $ histoM folder exp where + + folder :: ExpF (Cofree ExpF (Exp, Set Name)) -> NameM (Exp, Set Name) + folder exp = case exp of + -- middle case + EBindF ((ECase val alts1, leftUse) :< _) (VarPat lpatName) + (_ :< (EBindF ((ECase varName alts2, caseUse) :< _) lpat ((rightExp, rightUse) :< _))) + | lpatName == varName + , Just alts1Types <- sequence $ map (getReturnTagSet typeEnv) alts1 + , Just matchList <- disjointMatch (zip alts1Types alts1) alts2 + , Set.notMember varName rightUse -- allow only linear variables ; that are not used later + -> do + hoistedAlts <- mapM (hoistAlts lpatName) matchList + pure (EBind (ECase val hoistedAlts) lpat rightExp, Set.delete varName $ mconcat [leftUse, caseUse, rightUse]) + + -- last case + EBindF ((ECase val alts1, leftUse) :< _) (VarPat lpatName) ((ECase varName alts2, rightUse) :< _) + | lpatName == varName + , Just alts1Types <- sequence $ map (getReturnTagSet typeEnv) alts1 + , Just matchList <- disjointMatch (zip alts1Types alts1) alts2 + -> do + hoistedAlts <- mapM (hoistAlts lpatName) matchList + pure (ECase val hoistedAlts, Set.delete varName $ mconcat [leftUse, rightUse]) + + _ -> let useSub = Data.Foldable.fold (snd . extract <$> exp) + useExp = foldNameUseExpF Set.singleton exp + in pure (embed (fst . extract <$> exp), mconcat [useSub, useExp]) + +hoistAlts :: Name -> (Alt, Alt) -> NameM Alt +hoistAlts lpatName (Alt cpat1 altName1 alt1, Alt cpat2 altName2 alt2) = do + freshLPatName <- deriveNewName lpatName + let nameMap = Map.singleton lpatName freshLPatName + (freshAlt2, _) <- refreshNames nameMap $ + EBind (SReturn $ Var freshLPatName) (VarPat altName2) alt2 + pure . Alt cpat1 altName1 $ EBind (SBlock alt1) (VarPat freshLPatName) freshAlt2 + +disjointMatch :: [(Set Tag, Alt)] -> [Alt] -> Maybe [(Alt, Alt)] +disjointMatch tsAlts1 alts2 + | Just (defaults, tagMap) <- mconcat <$> mapM groupByCPats alts2 + , length defaults <= 1 + , Just (altPairs, _, _) <- Data.Foldable.foldrM (matchAlt tagMap) ([], defaults, Set.empty) tsAlts1 + = Just altPairs +disjointMatch _ _ = Nothing + +groupByCPats :: Alt -> Maybe ([Alt], Map Tag Alt) +groupByCPats alt@(Alt cpat _ _) = case cpat of + DefaultPat -> Just ([alt], mempty) + NodePat tag _ -> Just ([], Map.singleton tag alt) + _ -> Nothing + +matchAlt :: Map Tag Alt -> (Set Tag, Alt) -> ([(Alt, Alt)], [Alt], Set Tag) -> Maybe ([(Alt, Alt)], [Alt], Set Tag) +matchAlt tagMap (ts, alt1) (matchList, defaults, coveredTags) + -- regular node pattern + | Set.size ts == 1 + , tag <- Set.findMin ts + , Set.notMember tag coveredTags + , Just alt2 <- Map.lookup tag tagMap + = Just ((alt1, alt2):matchList, defaults, Set.insert tag coveredTags) + + -- default can handle this + | defaultAlt:[] <- defaults + , Data.Foldable.all (flip Set.notMember coveredTags) ts + = Just ((alt1, defaultAlt):matchList, [], coveredTags `mappend` ts) + + | otherwise = Nothing + +{- + TODO: + - add cloned variables to TypeEnv + done - ignore non linear scrutinee + IDEA: + this could be supported if product type was available in GRIN then the second case could return from the hoisted case with a pair of the original two case results +-} diff --git a/grin/test/Transformations/ExtendedSyntax/Optimising/CaseHoistingSpec.hs b/grin/test/Transformations/ExtendedSyntax/Optimising/CaseHoistingSpec.hs new file mode 100644 index 00000000..f161b63f --- /dev/null +++ b/grin/test/Transformations/ExtendedSyntax/Optimising/CaseHoistingSpec.hs @@ -0,0 +1,203 @@ +{-# LANGUAGE OverloadedStrings, QuasiQuotes, ViewPatterns #-} +module Transformations.ExtendedSyntax.Optimising.CaseHoistingSpec where + +import Transformations.ExtendedSyntax.Optimising.CaseHoisting + +import Test.Hspec + +import Grin.ExtendedSyntax.TH +import Grin.ExtendedSyntax.TypeCheck +import Test.ExtendedSyntax.Assertions +import Transformations.ExtendedSyntax.Names (ExpChanges(..)) + + +runTests :: IO () +runTests = hspec spec + +spec :: Spec +spec = do + it "last case" $ do + let before = [prog| + grinMain = + v <- pure (CNil) + u <- case v of + (CNil) @ alt1 -> pure (CNil) + (CCons a1 b1) @ alt2 -> pure (CCons a1 b1) + case u of + (CNil) @ alt3 -> pure alt3 + (CCons a2 b2) @ alt4 -> pure (CNil) + |] + let after = [prog| + grinMain = + v <- pure (CNil) + case v of + (CNil) @ alt1 -> + u.0 <- do + pure (CNil) + alt3.0 <- pure u.0 + pure alt3.0 + (CCons a1 b1) @ alt2 -> + u.1 <- do + pure (CCons a1 b1) + alt4.0 <- pure u.1 + pure (CNil) + |] + caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames) + + it "middle case" $ do + let before = [prog| + grinMain = + v <- pure (CNil) + u <- case v of + (CNil) @ alt1 -> pure (CNil) + (CCons a1 b1) @ alt2 -> pure (CCons a1 b1) + r <- case u of + (CNil) @ alt3 -> pure 1 + (CCons a2 b2) @ alt4 -> pure 2 + pure r + |] + let after = [prog| + grinMain = + v <- pure (CNil) + r <- case v of + (CNil) @ alt1 -> + u.0 <- do + pure (CNil) + alt3.0 <- pure u.0 + pure 1 + (CCons a1 b1) @ alt2 -> + u.1 <- do + pure (CCons a1 b1) + alt4.0 <- pure u.1 + pure 2 + pure r + |] + caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames) + + it "default pattern" $ do + let before = [prog| + grinMain = + v <- pure (CNil) + u <- case v of + (CNil) @ alt1 -> pure (CNil) + (CCons a1 b1) @ alt2 -> pure (CCons a1 b1) + r <- case u of + (CNil) @ alt3 -> pure (CNil) + #default @ alt4 -> pure alt4 + pure r + |] + let after = [prog| + grinMain = + v <- pure (CNil) + r <- case v of + (CNil) @ alt1 -> + u.0 <- do + pure (CNil) + alt3.0 <- pure u.0 + pure (CNil) + (CCons a1 b1) @ alt2 -> + u.1 <- do + pure (CCons a1 b1) + alt4.0 <- pure u.1 + pure alt4.0 + pure r + |] + caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames) + + it "case chain + no code duplication" $ do + let before = [prog| + grinMain = + v <- pure 1 + u <- case v of + 0 @ alt1 -> pure (CNil) + 1 @ alt2 -> pure (CCons v v) + r <- case u of + (CNil) @ alt3 -> pure (CEmpty) + #default @ alt4 -> pure u + q <- case r of + (CVoid) @ alt5 -> + pure (CEmpty) + #default @ alt6 -> + k0 <- pure 777 + _1 <- _prim_int_print k0 + pure r + pure q + |] + let after = [prog| + grinMain = + v <- pure 1 + r <- case v of + 0 @ alt1 -> + u.0 <- do + pure (CNil) + alt3.0 <- pure u.0 + pure (CEmpty) + 1 @ alt2 -> + u.1 <- do + pure (CCons v v) + alt4.0 <- pure u.1 + pure u.1 + q <- case r of + (CVoid) @ alt5 -> + pure (CEmpty) + #default @ alt6 -> + k0 <- pure 777 + _1 <- _prim_int_print k0 + pure r + pure q + |] + caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames) + + it "default chain" $ do + let before = [prog| + grinMain = + v <- pure 1 + u <- case v of + 0 @ alt1 -> pure (CNil) + 1 @ alt2 -> pure (CCons v v) + r <- case u of + #default @ alt3 -> pure u + q <- case r of + #default @ alt4 -> pure r + pure q + |] + let after = [prog| + grinMain = + v <- pure 1 + u <- case v of + 0 @ alt1 -> + pure (CNil) + 1 @ alt2 -> + pure (CCons v v) + q <- case u of + #default @ alt3 -> + r.0 <- do + pure u + alt4.0 <- pure r.0 + pure r.0 + pure q + |] + caseHoisting (inferTypeEnv before) before `sameAs` (after, NewNames) + + it "ignore non linear variable" $ do + let before = [prog| + grinMain = + v <- pure (CNil) + u <- case v of + #default @ alt1 -> pure v + r <- case u of + #default @ alt2 -> pure u + x <- pure u + pure r + |] + let after = [prog| + grinMain = + v <- pure (CNil) + u <- case v of + #default @ alt1 -> pure v + r <- case u of + #default @ alt2 -> pure u + x <- pure u + pure r + |] + caseHoisting (inferTypeEnv before) before `sameAs` (after, NoChange)