module GHC.Core.Opt.LiberateCase ( liberateCase ) where
import GHC.Prelude
import GHC.Driver.Session
import GHC.Core
import GHC.Core.Unfold
import GHC.Builtin.Types ( unitDataConId )
import GHC.Types.Id
import GHC.Types.Var.Env
import GHC.Utils.Misc    ( notNull )
-- | Run the liberate-case pass over a whole program.
--
-- Top-level bindings are processed left to right with 'libCaseBind'; each
-- binding sees the environment produced by the bindings before it, starting
-- from 'initLiberateCaseEnv'.  The transformed bindings keep their order.
liberateCase :: DynFlags -> CoreProgram -> CoreProgram
liberateCase dflags binds = do_prog (initLiberateCaseEnv dflags) binds
  where
    -- Thread the environment through the list of top-level bindings.
    do_prog :: LibCaseEnv -> CoreProgram -> CoreProgram
    do_prog _   []          = []
    do_prog env (bind:rest) = bind' : do_prog env' rest
      where
        (env', bind') = libCaseBind env bind
-- | The environment the pass starts from.
--
-- The size threshold and unfolding options are taken from the 'DynFlags'.
-- The level starts at 'topLevel' (0), and no binders, recursive groups or
-- scrutinised variables are recorded yet.
initLiberateCaseEnv :: DynFlags -> LibCaseEnv
initLiberateCaseEnv dflags = LibCaseEnv
   { lc_threshold = liberateCaseThreshold dflags
   , lc_uf_opts   = unfoldingOpts dflags
   , lc_lvl       = topLevel
   , lc_lvl_env   = emptyVarEnv
   , lc_rec_env   = emptyVarEnv
   , lc_scruts    = []
   }
-- | Process one binding.  Returns the environment for code in the binding's
-- scope, together with the transformed binding.
--
-- * 'NonRec': the RHS is processed in the incoming environment, and the
--   binder is then recorded at the current level.
--
-- * 'Rec': the returned environment records every binder of the group at
--   the current level.  If the group is "dupable" (small enough, and every
--   binder has arity > 0 and is not a dead-end Id), the RHSs are processed
--   in an environment extended by 'addRecBinds'.  That environment maps
--   each binder to a copy of the group, so that 'libCaseApp' can re-bind
--   the group locally around recursive calls.  Otherwise the RHSs are
--   processed in the incoming environment.
libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
libCaseBind env (NonRec binder rhs)
  = (addBinders env [binder], NonRec binder (libCase env rhs))

libCaseBind env (Rec pairs)
  = (env_body, Rec pairs')
  where
    binders = map fst pairs

    env_body = addBinders env binders

    pairs' = [ (binder, libCase env_rhs rhs) | (binder, rhs) <- pairs ]

    -- The copies stored by addRecBinds (dup_pairs) are themselves processed
    -- with env_body, which has no rec-env entry for this group.  Building
    -- the copies therefore does not unfold the group again.
    env_rhs | is_dupable_bind = addRecBinds env dup_pairs
            | otherwise       = env

    -- Copies of the group that libCaseApp may place in a local Let, so the
    -- binders are made local with 'localiseId'.
    dup_pairs = [ (localiseId binder, libCase env_body rhs)
                | (binder, rhs) <- pairs ]

    is_dupable_bind = small_enough && all ok_pair pairs

    -- Measure what would be duplicated by wrapping the copies in a dummy
    -- @let rec { dup_pairs } in ()@ and asking the unfolder whether it could
    -- fit within the threshold.  No threshold means "always small enough".
    small_enough = case lc_threshold env of
      Nothing   -> True
      Just size -> couldBeSmallEnoughToInline (lc_uf_opts env) size $
                   Let (Rec dup_pairs) (Var unitDataConId)

    -- Only binders with positive arity that are not dead-end Ids qualify.
    ok_pair (bndr, _) = idArity bndr > 0 && not (isDeadEndId bndr)
-- | Process an expression.
--
-- * An application whose head is a variable (including a bare variable)
--   goes through 'libCaseApp', which may wrap it in a local copy of an
--   enclosing recursive group.
-- * Binders introduced by lambdas, lets and case expressions are recorded
--   at the current level.
-- * A case whose scrutinee is a variable, possibly under casts, records
--   that variable via 'addScrutedVar' for its alternatives.
libCase :: LibCaseEnv
        -> CoreExpr
        -> CoreExpr
libCase env (Var v)             = libCaseApp env v []
libCase _   (Lit lit)           = Lit lit
libCase _   (Type ty)           = Type ty
libCase _   (Coercion co)       = Coercion co
libCase env e@(App {})
  | (Var v, args) <- collectArgs e
  = libCaseApp env v args
libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
libCase env (Tick tickish body) = Tick tickish (libCase env body)
libCase env (Cast e co)         = Cast (libCase env e) co

libCase env (Lam binder body)
  = Lam binder (libCase (addBinders env [binder]) body)

libCase env (Let bind body)
  = Let bind' (libCase env_body body)
  where
    (env_body, bind') = libCaseBind env bind

libCase env (Case scrut bndr ty alts)
  = Case (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
  where
    env_alts = addBinders (mk_alt_env scrut) [bndr]

    -- Record a scrutinised variable, looking through any casts.
    mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
    mk_alt_env (Cast inner _)  = mk_alt_env inner
    mk_alt_env _               = env
-- | Process a case alternative.  The alternative's binders are recorded at
-- the current level before its right-hand side is processed.
libCaseAlt :: LibCaseEnv -> Alt CoreBndr -> Alt CoreBndr
libCaseAlt env (Alt con args rhs)
  = Alt con args (libCase (addBinders env args) rhs)
-- | Process the application of variable @v@ to @args@.  The arguments are
-- processed with 'libCase'.
--
-- If @v@ has an entry in 'lc_rec_env' and 'freeScruts' reports at least one
-- variable for @v@'s level, the application is wrapped in a @Let@ of the
-- recorded recursive group.  Otherwise the application is returned as is.
libCaseApp :: LibCaseEnv -> Id -> [CoreExpr] -> CoreExpr
libCaseApp env v args
  | Just the_bind <- lookupRecId env v
  , notNull free_scruts
  = Let the_bind expr'
  | otherwise
  = expr'
  where
    rec_id_level = lookupLevel env v
    free_scruts  = freeScruts env rec_id_level
    expr'        = mkApps (Var v) (map (libCase env) args)
-- | The recorded scrutinised variables that were bound at or outside the
-- given level, and scrutinised at a level strictly deeper than it.
freeScruts :: LibCaseEnv
           -> LibCaseLevel      -- ^ Level of the recursive Id
           -> [Id]              -- ^ Scrutinised Ids passing the level test
freeScruts env rec_bind_lvl
  = [ v
    | (v, scrut_bind_lvl, scrut_at_lvl) <- lc_scruts env
    , scrut_bind_lvl <= rec_bind_lvl
    , scrut_at_lvl   >  rec_bind_lvl
    ]
-- | Record the given binders at the current level ('lc_lvl').  The current
-- level itself is left unchanged.
addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
addBinders env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env }) binders
  = env { lc_lvl_env = extendVarEnvList lvl_env [ (b, lvl) | b <- binders ] }
-- | Extend the environment for processing the RHSs of a recursive group.
--
-- * Each binder is recorded at the current level.
-- * Each binder is mapped to the whole group, as a 'Rec' binding, in
--   'lc_rec_env'.
-- * The current level is incremented by one.
addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
addRecBinds env@(LibCaseEnv { lc_lvl     = lvl
                            , lc_lvl_env = lvl_env
                            , lc_rec_env = rec_env }) pairs
  = env { lc_lvl     = lvl + 1
        , lc_lvl_env = extendVarEnvList lvl_env [ (b, lvl)   | b <- binders ]
        , lc_rec_env = extendVarEnvList rec_env [ (b, group) | b <- binders ]
        }
  where
    binders = map fst pairs
    group   = Rec pairs   -- the same group value is shared by every binder
-- | Note that @scrut_var@ is scrutinised by a case expression at the
-- current level.
--
-- The triple (variable, binding level, current level) is pushed onto
-- 'lc_scruts' only when the variable was bound at a level strictly below
-- the current one.  Otherwise the environment is returned unchanged.
addScrutedVar :: LibCaseEnv
              -> Id             -- ^ Id scrutinised by a case expression
              -> LibCaseEnv
addScrutedVar env@(LibCaseEnv { lc_lvl = lvl, lc_scruts = scruts }) scrut_var
  | bind_lvl < lvl = env { lc_scruts = (scrut_var, bind_lvl, lvl) : scruts }
  | otherwise      = env
  where
    -- Unrecorded Ids count as 'topLevel', exactly as in 'lookupLevel'.
    bind_lvl = lookupLevel env scrut_var
-- | The recursive group recorded for @v@ in 'lc_rec_env', if any.
lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
lookupRecId env v = lookupVarEnv (lc_rec_env env) v
-- | The level at which @v@ was recorded in 'lc_lvl_env'.  Ids with no entry
-- are given 'topLevel'.
lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
lookupLevel env v
  = case lookupVarEnv (lc_lvl_env env) v of
      Just lvl -> lvl
      Nothing  -> topLevel
-- | Nesting level used by the pass.  'lc_lvl' grows by one each time
-- 'addRecBinds' is applied.
type LibCaseLevel = Int

-- | The level given to Ids with no recorded level.  It equals the initial
-- value of 'lc_lvl' (0).
topLevel :: LibCaseLevel
topLevel = 0
-- | Environment threaded through the liberate-case pass.
data LibCaseEnv
  = LibCaseEnv
      { lc_threshold :: Maybe Int
          -- ^ Size limit used to decide whether a recursive group may be
          -- duplicated.  'Nothing' means no limit.  Initialised from
          -- 'liberateCaseThreshold'.
      , lc_uf_opts :: UnfoldingOpts
          -- ^ Unfolding options used when measuring a group's size.
      , lc_lvl :: LibCaseLevel
          -- ^ Current level.  It is only incremented by 'addRecBinds',
          -- which 'libCaseBind' uses for the RHSs of a recursive group it
          -- decides to duplicate.
      , lc_lvl_env :: IdEnv LibCaseLevel
          -- ^ Level at which each recorded binder was bound.  Ids with no
          -- entry count as 'topLevel'.
      , lc_rec_env :: IdEnv CoreBind
          -- ^ Maps each binder of a duplicable recursive group to a copy
          -- of the whole group.  Only extended in the environment used
          -- for that group's own RHSs.
      , lc_scruts :: [(Id, LibCaseLevel, LibCaseLevel)]
          -- ^ Variables scrutinised by an enclosing case expression.  Each
          -- entry holds the variable, its binding level, and the
          -- (strictly deeper) level at which it was scrutinised.
      }