%
% (c) The AQUA Project, Glasgow University, 1994-1998
%
\section[LiberateCase]{Unroll recursion to allow evals to be lifted from a loop}
\begin{code}
module LiberateCase ( liberateCase ) where
#include "HsVersions.h"
import DynFlags
import HscTypes
import CoreLint ( showPass, endPass )
import CoreSyn
import CoreUnfold ( couldBeSmallEnoughToInline )
import Rules ( RuleBase )
import UniqSupply ( UniqSupply )
import SimplMonad ( SimplCount, zeroSimplCount )
import Id
import FamInstEnv
import Type
import Coercion
import TyCon
import VarEnv
import Name ( localiseName )
import Outputable
import Util ( notNull )
import Data.IORef ( readIORef )
\end{code}
The liberate-case transformation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This module walks over @Core@, and looks for @case@ on free variables.
The criterion is:
if there is a case on a free variable on the route to the recursive call,
then the recursive call is replaced with an unfolding.
Example
f = \ t -> case v of
V a b -> a : f t
=> the inner f is replaced.
f = \ t -> case v of
V a b -> a : (letrec
f = \ t -> case v of
V a b -> a : f t
in f) t
(note the NEED for shadowing)
=> Simplify
f = \ t -> case v of
V a b -> a : (letrec
f = \ t -> a : f t
in f t)
Better code, because 'a' is free inside the inner letrec, rather
than needing projection from v.
Other examples we'd like to catch with this kind of transformation
last [] = error
last (x:[]) = x
last (x:xs) = last xs
We'd like to avoid the redundant pattern match, transforming to
last [] = error
last (x:[]) = x
last (x:(y:ys)) = last' y ys
where
last' y [] = y
last' _ (y:ys) = last' y ys
(is this necessarily an improvement?)
Similarly drop:
drop n [] = []
drop 0 xs = xs
drop n (x:xs) = drop (n-1) xs
Would like to pass n along unboxed.
Note [Scrutinee with cast]
~~~~~~~~~~~~~~~~~~~~~~~~~~
Consider this:
f = \ t -> case (v `cast` co) of
V a b -> a : f t
Exactly the same optimisation (unrolling one call to f) will work here,
despite the cast. See mk_alt_env in the Case branch of libCase.
To think about (Apr 94)
~~~~~~~~~~~~~~
Main worry: duplicating code excessively. At the moment we duplicate
the entire binding group once at each recursive call. But there may
be a group of recursive calls which share a common set of evaluated
free variables, in which case the duplication is a plain waste.
Another thing we could consider adding is some unfold-threshold thing,
so that we'll only duplicate if the size of the group rhss isn't too
big.
Data types
~~~~~~~~~~
The ``level'' of a binder tells how many
recursive defns lexically enclose the binding.
A recursive defn "encloses" its RHS, not its
scope. For example:
\begin{verbatim}
letrec f = let g = ... in ...
in
let h = ...
in ...
\end{verbatim}
Here, the level of @f@ is zero, the level of @g@ is one,
and the level of @h@ is zero (NB not one).
Note [Indexed data types]
~~~~~~~~~~~~~~~~~~~~~~~~~
Consider
data family T :: * -> *
data T Int = TI Int
f :: T Int -> Bool
f x = case x of { DEFAULT -> <body> }
We would like to change this to
f x = case x `cast` co of { TI p -> <body> }
so that we can make use of the fact that x is already evaluated to
a TI; and a case on a known data type may be more efficient than a
polymorphic one (not sure this is true any longer). Anyway the former
showed up in Roman's experiments. Example:
foo :: T Int -> Int -> Int
foo t n = t `seq` bar n
where
bar 0 = 0
bar n = bar (n - case t of TI i -> i)
Here we'd like to avoid repeatedly evaluating t inside the loop, by
taking advantage of the `seq`.
We implement this as part of the liberate-case transformation by
spotting
case <scrut> of (x::T tys) { DEFAULT -> <body> }
where x :: T tys, and T is an indexed family tycon. Find the
representation type (T77 tys'), and coercion co, and transform to
case <scrut> `cast` co of (y::T77 tys')
    DEFAULT -> let x = y `cast` sym co in <body>
The "find the representation type" part is done by looking up in the
family-instance environment.
NB: in fact we re-use x (changing its type) to avoid making a fresh y;
this entails shadowing, but that's ok.
%************************************************************************
%* *
Top-level code
%* *
%************************************************************************
\begin{code}
-- | Entry point of the pass.  Each top-level binding is processed by
-- 'libCaseBind' in order; the environment it returns is handed on to the
-- next binding, so later bindings see the binders of earlier ones.  The
-- pass records no simplifier ticks and always reports 'zeroSimplCount'.
liberateCase :: HscEnv -> UniqSupply -> RuleBase -> ModGuts
             -> IO (SimplCount, ModGuts)
liberateCase hsc_env _ _ guts = do
    eps <- readIORef (hsc_EPS hsc_env)
    let dflags   = hsc_dflags hsc_env
        famEnvs  = (eps_fam_inst_env eps, mg_fam_inst_env guts)
        startEnv = initEnv dflags famEnvs
        newBinds = goBinds startEnv (mg_binds guts)
    showPass dflags "Liberate case"
    endPass dflags "Liberate case" Opt_D_verbose_core2core newBinds
        -- no specific flag for dumping
    return (zeroSimplCount dflags, guts { mg_binds = newBinds })
  where
    -- Thread the environment left-to-right through the binding list.
    goBinds _   []         = []
    goBinds env (b : rest) = b' : goBinds env' rest
      where
        (env', b') = libCaseBind env b
\end{code}
%************************************************************************
%* *
Main payload
%* *
%************************************************************************
Bindings
~~~~~~~~
\begin{code}
-- | Process one binding group.  Returns the environment in which the
-- binding's scope is analysed (binders recorded at the current level),
-- together with the transformed binding.
libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
libCaseBind env (NonRec bndr rhs)
  = (addBinders env [bndr], NonRec bndr rhs')
  where
    rhs' = libCase env rhs

libCaseBind env (Rec pairs)
  = (scopeEnv, Rec [ (b, libCase rhsEnv e) | (b, e) <- pairs ])
  where
    scopeEnv = addBinders env (map fst pairs)

    -- The group is offered for unrolling inside its own RHSs only when
    -- every RHS is small enough to be worth duplicating.
    rhsEnv
      | all smallEnough (map snd pairs) = unrollEnv
      | otherwise                       = env

    -- Each Id is bound in the rec-env to its RHS.  Those stored copies are
    -- themselves processed with 'scopeEnv', whose rec-env does not yet
    -- contain this group, so the unrolling cannot repeat for ever.
    unrollEnv = addRecBinds env [ (localise b, libCase scopeEnv e)
                                | (b, e) <- pairs ]

    -- Two subtle things about the copied binders:
    --  (a) Reset the export flag, so that we don't get name clashes on
    --      exported things if the local binding floats out to top level.
    --      This is most unlikely, since the whole point concerns free
    --      variables, but resetting the export flag is right regardless.
    --  (b) Make the Name an Internal one.  External Names should never be
    --      nested; if one were floated to top level we'd get a name clash
    --      at code generation time.
    localise b = setIdNotExported (setIdName b (localiseName (idName b)))

    smallEnough = couldBeSmallEnoughToInline (bombOutSize env)
\end{code}
Expressions
~~~~~~~~~~~
\begin{code}
-- | Walk an expression, tracking binder levels and scrutinised variables,
-- and let 'libCaseId' unroll uses of enclosing recursive functions.
-- Case expressions are rebuilt through 'mkCase'
-- (see Note [Indexed data types]).
libCase :: LibCaseEnv
        -> CoreExpr
        -> CoreExpr
libCase env (Var v)          = libCaseId env v
libCase _   (Lit lit)        = Lit lit
libCase _   (Type ty)        = Type ty
libCase env (App fun arg)    = App (libCase env fun) (libCase env arg)
libCase env (Note note body) = Note note (libCase env body)
libCase env (Cast e co)      = Cast (libCase env e) co

libCase env (Lam binder body)
  = Lam binder (libCase (addBinders env [binder]) body)

libCase env (Let bind body)
  = Let bind' (libCase env_body body)
  where
    (env_body, bind') = libCaseBind env bind

libCase env (Case scrut bndr ty alts)
  = mkCase env (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
  where
    env_alts = addBinders (mk_alt_env scrut) [bndr]

    -- Only a scrutinee that is a variable (possibly under casts) is
    -- recorded as scrutinised.
    mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
    mk_alt_env (Cast e _)      = mk_alt_env e   -- Note [Scrutinee with cast]
    mk_alt_env _               = env

-- | Process one case alternative; the pattern-bound variables are
-- recorded at the current level before the RHS is walked.
libCaseAlt :: LibCaseEnv -> CoreAlt -> CoreAlt
libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
\end{code}
\begin{code}
mkCase :: LibCaseEnv -> CoreExpr -> Id -> Type -> [CoreAlt] -> CoreExpr
-- See Note [Indexed data types]
--
-- Build a case expression.  When the case has a single DEFAULT
-- alternative, and the case binder's type is a tycon application for
-- which lookupFamInstEnv finds exactly one family instance:
--   * the scrutinee is cast to the instance's representation type, and
--   * the original binder is re-bound, by a let around the alternative's
--     RHS, to the new binder cast back with the symmetric coercion.
-- In every other situation the Case is built unchanged.
mkCase env scrut bndr ty [(DEFAULT,_,rhs)]
  | Just (tycon, tys)   <- splitTyConApp_maybe (idType bndr)
  , [(subst, fam_inst)] <- lookupFamInstEnv (lc_fams env) tycon tys
  = let
        rep_tc     = famInstTyCon fam_inst
        rep_tys    = map (substTyVar subst) (tyConTyVars rep_tc)
        -- Re-use the binder itself (only its type changes) rather than
        -- making a fresh one; the let below shadows it at the old type.
        bndr'      = setIdType bndr (mkTyConApp rep_tc rep_tys)
        -- NOTE(review): irrefutable pattern.  This assumes every
        -- representation TyCon obtained via famInstTyCon has a family
        -- coercion.  A Nothing here would be a runtime pattern-match
        -- failure, not a fall-through.
        Just co_tc = tyConFamilyCoercion_maybe rep_tc
        co         = mkTyConApp co_tc rep_tys
        bind       = NonRec bndr (Cast (Var bndr') (mkSymCoercion co))
        -- The recursive call re-examines the new binder.  Presumably its
        -- representation type is not itself a family instance, so this
        -- ends in the plain Case below -- TODO confirm.
    in mkCase env (Cast scrut co) bndr' ty [(DEFAULT,[],Let bind rhs)]
mkCase env scrut bndr ty alts
  = Case scrut bndr ty alts
\end{code}
Ids
~~~
\begin{code}
-- | An occurrence of an Id.  If it is a use of an enclosing recursive
-- function, and some variable has been scrutinised since that function's
-- binding group was entered, wrap the occurrence in a local copy of the
-- whole group.  Otherwise leave the variable alone.
libCaseId :: LibCaseEnv -> Id -> CoreExpr
libCaseId env v =
  case lookupRecId env v of
    Just grp
      | notNull (freeScruts env (lookupLevel env v))
      -> Let grp (Var v)
    _ -> Var v
\end{code}
%************************************************************************
%* *
Utility functions
%* *
%************************************************************************
\begin{code}
-- | Record each of the given binders as bound at the current level.
addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
addBinders env binders
  = env { lc_lvl_env = extendVarEnvList (lc_lvl_env env)
                                        [ (b, here) | b <- binders ] }
  where
    here = lc_lvl env
-- | Enter the RHSs of a recursive group.  The group's binders are recorded
-- at the current (outer) level, each binder is mapped in the rec-env to
-- the whole group, and the current level is incremented.
addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
addRecBinds env pairs
  = env { lc_lvl     = outer + 1
        , lc_lvl_env = extendVarEnvList (lc_lvl_env env)
                                        (map (\b -> (b, outer)) bndrs)
        , lc_rec_env = extendVarEnvList (lc_rec_env env)
                                        (map (\b -> (b, Rec pairs)) bndrs) }
  where
    outer = lc_lvl env
    bndrs = map fst pairs
-- | Note that a variable is scrutinised by a case expression.  It is only
-- recorded, tagged with the current level, when it was bound at a
-- shallower level than the current one.  Ids absent from the level env
-- count as 'topLevel'.
addScrutedVar :: LibCaseEnv
              -> Id   -- This Id is being scrutinised by a case expression
              -> LibCaseEnv
addScrutedVar env scrut_var
  | defn_lvl >= cur_lvl = env
  | otherwise           = env { lc_scruts = (scrut_var, cur_lvl) : lc_scruts env }
  where
    cur_lvl  = lc_lvl env
    defn_lvl = maybe topLevel id (lookupVarEnv (lc_lvl_env env) scrut_var)
-- | The binding group of a recursive Id, if we are inside its own RHSs.
lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
lookupRecId = lookupVarEnv . lc_rec_env
-- | The level at which an Id was bound.  Ids not in 'lc_lvl_env'
-- (top-level and imported things) are at 'topLevel'.
--
-- This must return the binder's own recorded level, not the current
-- level.  'freeScruts' keeps scrutinies with a level strictly greater than
-- this value, and every scrutiny is recorded at a level no greater than
-- the current one.  Returning the current level would therefore make
-- 'libCaseId' never unroll anything.
lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
lookupLevel env v
  = case lookupVarEnv (lc_lvl_env env) v of
      Just lvl -> lvl
      Nothing  -> topLevel
-- | The recorded scrutinised Ids whose scrutiny level is strictly deeper
-- than the given level of a recursive Id.
freeScruts :: LibCaseEnv
           -> LibCaseLevel   -- Level of the recursive Id
           -> [Id]           -- Ids that are scrutinised between the binding
                             -- of the recursive Id and here
freeScruts env rec_bind_lvl
  = map fst (filter deeper (lc_scruts env))
  where
    deeper (_, scrut_lvl) = scrut_lvl > rec_bind_lvl
\end{code}
%************************************************************************
%* *
The environment
%* *
%************************************************************************
\begin{code}
-- | Number of recursive bindings whose RHS lexically encloses a point.
type LibCaseLevel = Int
-- | The level of top-level and imported things (no enclosing recursive RHS).
topLevel :: LibCaseLevel
topLevel = 0
\end{code}
\begin{code}
-- | State carried down the tree by the liberate-case traversal.
data LibCaseEnv
  = LibCaseEnv
      { lc_size    :: Int
            -- ^ Bomb-out size for deciding if potential liberatees are
            -- too big (passed in from cmd-line args)
      , lc_lvl     :: LibCaseLevel
            -- ^ Current level
      , lc_lvl_env :: IdEnv LibCaseLevel
            -- ^ Binds all non-top-level in-scope Ids (top-level and
            -- imported things have a level of zero)
      , lc_rec_env :: IdEnv CoreBind
            -- ^ Binds *only* recursively defined ids, to their own
            -- binding group, and *only* in their own RHSs
      , lc_scruts  :: [(Id,LibCaseLevel)]
            -- ^ Each of these Ids was scrutinised by an enclosing case
            -- expression, with the specified number of enclosing
            -- recursive bindings; furthermore, the Id is bound at a lower
            -- level than the case expression.  The order is
            -- insignificant; it's a bag really
      , lc_fams    :: FamInstEnvs
            -- ^ Instance env for indexed data types
      }
-- | The environment at the start of a module.  Nothing is bound yet, and
-- the current level is 'topLevel'.  The size limit comes from
-- 'specThreshold' in the DynFlags.
initEnv :: DynFlags -> FamInstEnvs -> LibCaseEnv
initEnv dflags fams
  = LibCaseEnv { lc_size    = specThreshold dflags,
                 lc_lvl     = topLevel,
                 lc_lvl_env = emptyVarEnv,
                 lc_rec_env = emptyVarEnv,
                 lc_scruts  = [],
                 lc_fams    = fams }

-- | Size limit a recursive group's RHSs must fit within to be unrolled.
bombOutSize :: LibCaseEnv -> Int
bombOutSize = lc_size
\end{code}