{-|
Copyright   : (C) 2020-2021, QBayLogic B.V.,
                  2022     , Google Inc.
License     : BSD2 (see the file LICENSE)
Maintainer  : QBayLogic B.V. <devops@qbaylogic.com>

This module provides the "evaluation" part of the partial evaluator. This
is implemented in the classic "eval/apply" style, with a variant of apply for
performing type applications.
-}

{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE OverloadedStrings #-}

module Clash.GHC.PartialEval.Eval
  ( eval
  , apply
  , applyTy
  ) where

import           Control.Monad (foldM)
import           Data.Bifunctor
import           Data.Bitraversable
import           Data.Either
import           Data.Maybe
import           Data.Primitive.ByteArray (ByteArray(..))
#if MIN_VERSION_base(4,15,0)
import           GHC.Num.Integer (Integer (..))
#else
import           GHC.Integer.GMP.Internals (BigNat(..), Integer(..))
#endif

import           GHC.BasicTypes.Extra (isNoInline)

import           Clash.Core.DataCon (DataCon(..))
import           Clash.Core.HasType
import           Clash.Core.Literal (Literal(..))
import           Clash.Core.PartialEval.AsTerm
import           Clash.Core.PartialEval.Monad
import           Clash.Core.PartialEval.NormalForm
import           Clash.Core.Subst (substTy)
import           Clash.Core.Term
import           Clash.Core.TyCon (tyConDataCons)
import           Clash.Core.Type
import           Clash.Core.TysPrim (integerPrimTy)
import           Clash.Core.Var
import qualified Clash.Data.UniqMap as UniqMap
import           Clash.Driver.Types (Binding(..), IsPrim(..))
import qualified Clash.Normalize.Primitives as NP (undefined, undefinedX)

-- | Evaluate a term to weak head normal form (WHNF).
--
-- Dispatches on the outermost constructor of the term. Term and type
-- applications are both handled by 'evalApp', which receives the argument
-- as 'Left' (a term argument) or 'Right' (a type argument).
--
eval :: Term -> Eval Value
eval term =
  case term of
    Var v              -> evalVar v
    Literal lit        -> pure (VLiteral lit)
    Data dc            -> evalData dc
    Prim pr            -> evalPrim pr
    Lam v body         -> evalLam v body
    TyLam tv body      -> evalTyLam tv body
    App fun arg        -> evalApp fun (Left arg)
    TyApp fun argTy    -> evalApp fun (Right argTy)
    Let bind body      -> evalLet bind body
    Case scrut ty alts -> evalCase scrut ty alts
    Cast x fromTy toTy -> evalCast x fromTy toTy
    Tick tick x        -> evalTick tick x

-- | Delay the evaluation of a term.
--
-- Literals and (type) lambdas are evaluated straight away. Ticks are kept
-- around the delayed inner term. Every other term is suspended as a 'VThunk'
-- that captures the current local environment, to be evaluated later by
-- 'forceEval'.
--
delayEval :: Term -> Eval Value
delayEval term =
  case term of
    Literal lit -> pure (VLiteral lit)
    Lam v body -> evalLam v body
    TyLam tv body -> evalTyLam tv body
    Tick tick inner -> do
      delayed <- delayEval inner
      pure (VTick delayed tick)
    _ -> do
      env <- getLocalEnv
      pure (VThunk term env)

-- | Force a value to WHNF without adding any extra bindings.
-- See 'forceEvalWith'.
--
forceEval :: Value -> Eval Value
forceEval value = forceEvalWith [] [] value

-- | Force a value to WHNF with extra type and term variable bindings.
--
-- Only a 'VThunk' performs work. The types in the type-variable bindings are
-- evaluated in the current environment. The thunk's term is then evaluated
-- inside the environment it captured, with the given type and term bindings
-- brought into scope. Any other value is returned unchanged, and the extra
-- bindings are ignored.
--
forceEvalWith :: [(TyVar, Type)] -> [(Id, Value)] -> Value -> Eval Value
forceEvalWith tvs ids value =
  case value of
    VThunk term env -> do
      -- Evaluated before switching to the thunk's captured environment.
      evaluatedTvs <- mapM (\(tv, ty) -> (,) tv <$> evalType ty) tvs
      setLocalEnv env $
        withTyVars evaluatedTvs $
          withIds ids $
            eval term

    _ -> pure value

-- | Delay a single application argument.
--
-- A term argument is suspended with 'delayEval'. A type argument is
-- evaluated immediately with 'evalType'.
--
delayArg :: Arg Term -> Eval (Arg Value)
delayArg = either (fmap Left . delayEval) (fmap Right . evalType)

-- | Delay every argument of an application spine, left to right.
--
delayArgs :: Args Term -> Eval (Args Value)
delayArgs args = mapM delayArg args

-- | Evaluate a type in the current environment.
--
-- The current type-variable substitution is applied first. The result is
-- then normalized using the type constructor map.
--
evalType :: Type -> Eval Type
evalType ty =
  normalizeIn <$> getTyConMap <*> getTvSubst
 where
  normalizeIn tcm subst = normalizeType tcm (substTy subst ty)

-- | Evaluate a variable reference.
--
-- Local identifiers are resolved with 'lookupLocal'. All others are resolved
-- against the global bindings with 'lookupGlobal'.
--
evalVar :: Id -> Eval Value
evalVar i =
  if isLocalId i
    then lookupLocal i
    else lookupGlobal i

-- | Look up a local variable in the environment.
--
-- A bound value is forced only when 'workFreeValue' reports that doing so
-- performs no work. Otherwise, or when the variable is unbound, the result is
-- a neutral variable whose type has been evaluated with 'evalType'.
--
lookupLocal :: Id -> Eval Value
lookupLocal i = do
  found <- findId i
  ty <- evalType (varType i)
  let stuck = pure (VNeutral (NeVar (i { varType = ty })))

  case found of
    Nothing -> stuck
    Just bound -> do
      cheap <- workFreeValue bound
      if cheap then forceEval bound else stuck

-- | Look up a global binding and inline it when allowed.
--
-- The result is a neutral variable when:
--
--   * no binding is found,
--   * the binding is NOINLINE and not a Clash primitive (Clash primitives
--     must be marked NOINLINE, so that marker alone does not block them), or
--   * no fuel is left.
--
-- Otherwise the binding's body is forced inside 'withContext' / 'withFuel'
-- (using one unit of fuel). The forced value is written back with
-- 'replaceBinding' so later lookups reuse it.
--
lookupGlobal :: Id -> Eval Value
lookupGlobal i = do
  fuel <- getFuel
  found <- findBinding i
  let stuck = pure (VNeutral (NeVar i))

  case found of
    Nothing -> stuck
    Just b
      | isNoInline (bindingSpec b), bindingIsPrim b == IsFun -> stuck
      | fuel == 0 -> stuck
      | otherwise -> withContext i (withFuel (inlineBinding b))
 where
  inlineBinding b = do
    forced <- forceEval (bindingTerm b)
    replaceBinding (b { bindingTerm = forced })
    pure forced

-- | Evaluate a bare data constructor.
--
-- A constructor that takes no arguments is already a 'VData' value. Any
-- other constructor is eta-expanded and the resulting lambda is evaluated.
--
evalData :: DataCon -> Eval Value
evalData dc =
  if fullyApplied (dcType dc) []
    then do
      env <- getLocalEnv
      pure (VData dc [] env)
    else eval =<< etaExpand (Data dc)

-- | Evaluate a bare primitive.
--
-- A primitive that takes no arguments goes straight to 'evalPrimOp'. Any
-- other primitive is eta-expanded and the resulting lambda is evaluated.
--
evalPrim :: PrimInfo -> Eval Value
evalPrim pr =
  if fullyApplied (primType pr) []
    then evalPrimOp pr []
    else eval =<< etaExpand (Prim pr)

-- | Evaluate a primitive applied to (delayed) arguments.
--
-- TODO Hook up to primitive evaluation skeleton. For now every primitive
-- application is left stuck as a neutral 'NePrim'.
--
evalPrimOp :: PrimInfo -> Args Value -> Eval Value
evalPrimOp pr = pure . VNeutral . NePrim pr

-- | Check whether a head of the given type receives exactly as many
-- arguments (term and type arguments combined) as the binders produced by
-- 'splitFunForallTy' on its type.
--
fullyApplied :: Type -> Args a -> Bool
fullyApplied ty args = length args == arity
 where
  arity = length (fst (splitFunForallTy ty))

-- | Eta-expand an application headed by a data constructor or primitive so
-- that it becomes saturated.
--
-- Binders are derived from the type that remains after applying the existing
-- arguments. Missing term arguments get fresh identifiers created by
-- 'getUniqueId' with the name "eta". Remaining type variables are reused
-- directly as type binders. The original term is applied to these binders
-- and abstracted over them. Terms with any other head are returned
-- unchanged.
--
etaExpand :: Term -> Eval Term
etaExpand term = do
  tcm <- getTyConMap

  case collectArgs term of
    (hd@(Data dc), args) -> saturate tcm (dcType dc) hd args
    (hd@(Prim pr), args) -> saturate tcm (primType pr) hd args
    _ -> pure term
 where
  freshBinder :: Either a Type -> Eval (Either Id a)
  freshBinder (Left tv) = pure (Right tv)
  freshBinder (Right argTy) = Left <$> getUniqueId "eta" argTy

  saturate :: TyConMap -> Type -> Term -> Args Term -> Eval Term
  saturate tcm headTy hd args = do
    let remaining = fst (splitFunForallTy (applyTypeToArgs hd tcm headTy args))
    binders <- mapM freshBinder remaining
    let binderArgs = map (either (Left . Var) (Right . VarTy)) binders
    pure (mkAbstraction (mkApps term binderArgs) binders)

-- | Evaluate a lambda to a 'VLam' closure.
--
-- The binder's type is evaluated, and the body is stored unevaluated with
-- the current local environment.
--
evalLam :: Id -> Term -> Eval Value
evalLam i body = do
  ty <- evalType (varType i)
  VLam (i { varType = ty }) body <$> getLocalEnv

-- | Evaluate a type lambda to a 'VTyLam' closure.
--
-- The binder's kind is evaluated, and the body is stored unevaluated with
-- the current local environment.
--
evalTyLam :: TyVar -> Term -> Eval Value
evalTyLam tv body = do
  kind <- evalType (varType tv)
  VTyLam (tv { varType = kind }) body <$> getLocalEnv

-- | Evaluate an application @x y@, where @y@ is a term ('Left') or a type
-- ('Right') argument.
--
-- The whole application spine is collected and handled based on its head:
--
--   * Data constructor: when saturated, the result is a 'VData' with delayed
--     arguments. Otherwise the term is eta-expanded and evaluated.
--   * Primitive: under-applied primitives are eta-expanded. Saturated ones
--     are evaluated with 'evalPrimOp'. Over-applied ones are evaluated on
--     their own arguments first, and the result is then applied to the
--     remaining arguments.
--   * Anything else: the head is evaluated and applied to the delayed
--     arguments, all wrapped in 'preserveFuel'.
--
-- NOTE(review): ticks collected from the spine are currently not re-attached
-- to the result.
--
evalApp :: Term -> Arg Term -> Eval Value
evalApp x y
  | Data dc <- f
  = if fullyApplied (dcType dc) args
      then do
        argThunks <- delayArgs args
        VData dc argThunks <$> getLocalEnv

      else etaExpand term >>= eval

  | Prim pr <- f
  , prArgs  <- fst $ splitFunForallTy (primType pr)
  , numArgs <- length prArgs
  = case compare (length args) numArgs of
      LT ->
        etaExpand term >>= eval

      EQ ->
        evalSaturatedPrim pr prArgs args

      GT -> do
        let (pArgs, rArgs) = splitAt numArgs args
        -- Bind the primitive's type variables here too, exactly as in the
        -- saturated case (previously this branch skipped 'withTyVars').
        primRes <- evalSaturatedPrim pr prArgs pArgs
        rArgThunks <- delayArgs rArgs

        foldM applyArg primRes rArgThunks

  | otherwise
  = preserveFuel $ do
      evalF <- eval f
      argThunks <- delayArgs args
      foldM applyArg evalF argThunks
 where
  term :: Term
  term = either (App x) (TyApp x) y

  (f, args, _ticks) = collectArgsTicks term

  -- Evaluate a primitive applied to exactly as many arguments as its type
  -- has binders. The primitive's forall-bound type variables are bound to
  -- the supplied type arguments while 'evalPrimOp' runs.
  evalSaturatedPrim
    :: PrimInfo -> [Either TyVar Type] -> Args Term -> Eval Value
  evalSaturatedPrim pr prArgs pArgs = do
    argThunks <- delayArgs pArgs
    let tyVars = lefts prArgs
        tyArgs = rights pArgs

    withTyVars (zip tyVars tyArgs) (evalPrimOp pr argThunks)

-- | Evaluate a let expression.
--
-- Non-recursive: the right-hand side is delayed and bound while the body is
-- evaluated. If the bound value is work-free, it can be inlined at its uses
-- (see 'lookupLocal'), so the let is dropped. Otherwise it is kept as a
-- neutral 'NeLet' to avoid duplicating work.
--
-- Recursive: every right-hand side is delayed and the body is evaluated with
-- all of them bound. The result is always a neutral 'NeLet'.
--
-- NOTE(review): recursive right-hand sides are delayed in the environment of
-- the let itself, i.e. without their sibling bindings in scope. Confirm this
-- is intended.
--
evalLet :: Bind Term -> Term -> Eval Value
evalLet (NonRec i x) body = do
  ty <- evalType (varType i)
  bound <- delayEval x
  cheap <- workFreeValue bound

  result <- withId i bound (eval body)

  pure $
    if cheap
      then result
      else VNeutral (NeLet (NonRec (i { varType = ty }) bound) result)

evalLet (Rec xs) body = do
  binds <- mapM delayBind xs
  result <- withIds binds (eval body)

  pure (VNeutral (NeLet (Rec binds) result))
 where
  delayBind :: (Var a, Term) -> Eval (Var a, Value)
  delayBind (v, rhs) = do
    ty <- evalType (varType v)
    delayed <- delayEval rhs
    pure (v { varType = ty }, delayed)

-- | Evaluate a case expression.
--
-- The subject and the alternatives are delayed and the result type is
-- evaluated. Selecting an alternative is left to 'caseCon'.
--
evalCase :: Term -> Type -> [Alt] -> Eval Value
evalCase scrut ty alts = do
  delayedScrut <- delayEval scrut
  evaluatedTy <- evalType ty
  delayedAlts <- delayAlts alts

  caseCon delayedScrut evaluatedTy delayedAlts

-- | Attempt to apply the case-of-known-constructor transformation on a case
-- expression.
--
-- The subject is forced first:
--
--   * An undefined subject makes the whole expression undefined.
--   * A known literal or data constructor must match an alternative;
--     otherwise evaluation fails with an error.
--   * A neutral Clash primitive is matched with 'matchClashPrim'. If nothing
--     matches, the case stays neutral.
--   * Anything else is handed to 'tryTransformCase', which tries to expose
--     more opportunities.
--
caseCon :: Value -> Type -> [(Pat, Value)] -> Eval Value
caseCon subject ty alts = do
  forcedSubject <- keepLifted (forceEval subject)
  dispatch forcedSubject
 where
  dispatch :: Value -> Eval Value
  dispatch forced
    | isUndefinedX forced = eval (TyApp (Prim NP.undefinedX) ty)
    | isUndefined forced  = eval (TyApp (Prim NP.undefined) ty)
    | otherwise =
        case stripValue forced of
          VLiteral lit ->
            findBestAlt (matchLiteral lit) alts
              >>= evalAlt (noMatch (show lit))

          -- The constructor's environment is the same as the current one.
          VData dc args _env ->
            findBestAlt (matchData dc args) alts
              >>= evalAlt (noMatch (show dc))

          -- Neutral primitives may be Clash primitives that act as values,
          -- e.g. fromInteger# for various types in clash-prelude.
          VNeutral (NePrim pr args) ->
            findBestAlt (matchClashPrim pr args) alts
              >>= evalAlt (VNeutral (NeCase forced ty alts))

          _ -> tryTransformCase forced ty alts

  -- Fallback for a known subject that no alternative matches. Only
  -- evaluated if 'evalAlt' receives 'NoMatch'.
  noMatch :: String -> Value
  noMatch shown =
    error ("caseCon: No pattern matched " <> shown <> " in " <> show alts)

-- | Attempt to transform a case expression so that 'caseCon' gets more
-- opportunities. When no transformation applies, the case expression stays
-- neutral.
--
--   * Case of case: when every inner alternative forces to something
--     'caseCon' can match on, the outer case is pushed into each inner
--     alternative and 'caseCon' is retried on the inner subject.
--   * Case of let: the let is floated out and 'caseCon' is retried on the
--     let's body.
--
tryTransformCase :: Value -> Type -> [(Pat, Value)] -> Eval Value
tryTransformCase subject ty alts =
  case stripValue subject of
    VNeutral (NeCase innerSubject _ innerAlts) ->
      caseOfCase innerSubject innerAlts

    VNeutral (NeLet bindings innerSubject) ->
      VNeutral . NeLet bindings <$> caseCon innerSubject ty alts

    -- TODO elimExistentials here.
    _ -> stuck
 where
  stuck :: Eval Value
  stuck = pure (VNeutral (NeCase subject ty alts))

  caseOfCase :: Value -> [(Pat, Value)] -> Eval Value
  caseOfCase innerSubject innerAlts = do
    forcedAlts <- forceAlts innerAlts
    if any (not . isKnown . snd) forcedAlts
      then stuck
      else caseCon innerSubject ty
             [ (pat, VNeutral (NeCase rhs ty alts)) | (pat, rhs) <- innerAlts ]

  -- Values 'caseCon' can match on directly.
  --
  -- TODO We may also care if it is another case of case?
  isKnown :: Value -> Bool
  isKnown value =
    case value of
      VLiteral{} -> True
      VData{} -> True
      VNeutral (NePrim pr _) ->
        primName pr `elem`
          [ "Clash.Sized.Internal.BitVector.fromInteger##"
          , "Clash.Sized.Internal.BitVector.fromInteger#"
          , "Clash.Sized.Internal.Index.fromInteger#"
          , "Clash.Sized.Internal.Signed.fromInteger#"
          , "Clash.Sized.Internal.Unsigned.fromInteger#"
          ]
      _ -> False

-- | Delay the alternatives of a case expression.
--
-- Each right-hand side is suspended with 'delayEval'. The types of the
-- variables bound by a 'DataPat' are evaluated. Other patterns are kept
-- unchanged.
--
delayAlts :: [Alt] -> Eval [(Pat, Value)]
delayAlts alts = mapM delayAlt alts
 where
  delayAlt (pat, rhs) = do
    pat' <- delayPat pat
    rhs' <- delayEval rhs
    pure (pat', rhs')

  delayPat :: Pat -> Eval Pat
  delayPat (DataPat dc tvs ids) = do
    tvs' <- mapM retype tvs
    ids' <- mapM retype ids
    pure (DataPat dc tvs' ids')
  delayPat pat = pure pat

  retype :: Var a -> Eval (Var a)
  retype v = do
    ty <- evalType (varType v)
    pure (v { varType = ty })

-- | Force the right-hand side of every alternative, keeping its pattern.
--
forceAlts :: [(Pat, Value)] -> Eval [(Pat, Value)]
forceAlts alts = mapM forceAlt alts
 where
  forceAlt (pat, rhs) = (,) pat <$> forceEval rhs

-- | Outcome of matching one case alternative against a subject.
data PatResult
  -- | The alternative matched. Carries the alternative itself, then the
  -- type-variable and term-variable bindings that 'evalAlt' brings into
  -- scope when forcing its right-hand side.
  = Match   (Pat, Value) [(TyVar, Type)] [(Id, Value)]
  -- | The alternative did not match.
  | NoMatch

-- | Evaluate the outcome of alternative selection.
--
-- On a match, the alternative's right-hand side is forced with the pattern's
-- bindings in scope. On no match, the fallback value is returned; it is only
-- evaluated in that case.
--
evalAlt :: Value -> PatResult -> Eval Value
evalAlt _ (Match (_, rhs) tyBinds idBinds) = forceEvalWith tyBinds idBinds rhs
evalAlt fallback NoMatch = pure fallback

matchLiteral :: Literal -> (Pat, Value) -> Eval PatResult
matchLiteral :: Literal -> (Pat, Value) -> Eval PatResult
matchLiteral Literal
lit alt :: (Pat, Value)
alt@(Pat
pat, Value
_) =
  case Pat
pat of
    DataPat DataCon
dc [] [Id
i]
      |  IntegerLiteral Integer
n <- Literal
lit
      -> case Integer
n of
#if MIN_VERSION_base(4,15,0)
           IS Int#
_
#else
           S# _
#endif
             | DataCon -> ConTag
dcTag DataCon
dc ConTag -> ConTag -> Bool
forall a. Eq a => a -> a -> Bool
== ConTag
1 -> PatResult -> Eval PatResult
forall a. a -> Eval a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure (PatResult -> Eval PatResult) -> PatResult -> Eval PatResult
forall a b. (a -> b) -> a -> b
$ (Pat, Value) -> [(TyVar, Type)] -> [(Id, Value)] -> PatResult
Match (Pat, Value)
alt [] [(Id
i, Literal -> Value
VLiteral (Integer -> Literal
IntLiteral Integer
n))]

#if MIN_VERSION_base(4,15,0)
           IP ByteArray#
bn
#else
           Jp# bn
#endif
             | DataCon -> ConTag
dcTag DataCon
dc ConTag -> ConTag -> Bool
forall a. Eq a => a -> a -> Bool
== ConTag
2 -> Id -> ByteArray# -> Eval PatResult
matchBigNat Id
i ByteArray#
bn

#if MIN_VERSION_base(4,15,0)
           IN ByteArray#
bn
#else
           Jn# bn
#endif
             | DataCon -> ConTag
dcTag DataCon
dc ConTag -> ConTag -> Bool
forall a. Eq a => a -> a -> Bool
== ConTag
3 -> Id -> ByteArray# -> Eval PatResult
matchBigNat Id
i ByteArray#
bn

           Integer
_ -> PatResult -> Eval PatResult
forall a. a -> Eval a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure PatResult
NoMatch

      |  NaturalLiteral Integer
n <- Literal
lit
      -> case Integer
n of
#if MIN_VERSION_base(4,15,0)
           IS Int#
_
#else
           S# _
#endif
             | DataCon -> ConTag
dcTag DataCon
dc ConTag -> ConTag -> Bool
forall a. Eq a => a -> a -> Bool
== ConTag
1 -> PatResult -> Eval PatResult
forall a. a -> Eval a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure (PatResult -> Eval PatResult) -> PatResult -> Eval PatResult
forall a b. (a -> b) -> a -> b
$ (Pat, Value) -> [(TyVar, Type)] -> [(Id, Value)] -> PatResult
Match (Pat, Value)
alt [] [(Id
i, Literal -> Value
VLiteral (Integer -> Literal
WordLiteral Integer
n))]

#if MIN_VERSION_base(4,15,0)
           IP ByteArray#
bn
#else
           Jp# bn
#endif
             | DataCon -> ConTag
dcTag DataCon
dc ConTag -> ConTag -> Bool
forall a. Eq a => a -> a -> Bool
== ConTag
2 -> Id -> ByteArray# -> Eval PatResult
matchBigNat Id
i ByteArray#
bn

           Integer
_ -> PatResult -> Eval PatResult
forall a. a -> Eval a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure PatResult
NoMatch

    LitPat Literal
n
      | Literal
lit Literal -> Literal -> Bool
forall a. Eq a => a -> a -> Bool
== Literal
n -> PatResult -> Eval PatResult
forall a. a -> Eval a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure (PatResult -> Eval PatResult) -> PatResult -> Eval PatResult
forall a b. (a -> b) -> a -> b
$ (Pat, Value) -> [(TyVar, Type)] -> [(Id, Value)] -> PatResult
Match (Pat, Value)
alt [] []

    Pat
DefaultPat -> PatResult -> Eval PatResult
forall a. a -> Eval a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure (PatResult -> Eval PatResult) -> PatResult -> Eval PatResult
forall a b. (a -> b) -> a -> b
$ (Pat, Value) -> [(TyVar, Type)] -> [(Id, Value)] -> PatResult
Match (Pat, Value)
alt [] []

    Pat
_ -> PatResult -> Eval PatResult
forall a. a -> Eval a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure PatResult
NoMatch
 where
  -- Somewhat of a hack: We find the constructor for BigNat and apply a
  -- ByteArray literal made from the given ByteArray to it.
#if MIN_VERSION_base(4,15,0)
  matchBigNat :: Id -> ByteArray# -> Eval PatResult
matchBigNat Id
i ByteArray#
ba = do
#else
  matchBigNat i (BN# ba) = do
#endif
    tcm <- Eval TyConMap
getTyConMap
    let bnDcM = do
          integerTcName <- ((TyConName, [Type]) -> TyConName)
-> Maybe (TyConName, [Type]) -> Maybe TyConName
forall a b. (a -> b) -> Maybe a -> Maybe b
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
fmap (TyConName, [Type]) -> TyConName
forall a b. (a, b) -> a
fst (Type -> Maybe (TyConName, [Type])
splitTyConAppM Type
integerPrimTy)
          [_, jpDc, _]  <- pure (tyConDataCons (UniqMap.find integerTcName tcm))
          ([bnTy], _)   <- pure (splitFunTys tcm (dcType jpDc))
          bnTcName      <- fmap fst (splitTyConAppM bnTy)
          listToMaybe (tyConDataCons (UniqMap.find bnTcName tcm))

        bnDc = DataCon -> Maybe DataCon -> DataCon
forall a. a -> Maybe a -> a
fromMaybe ([Char] -> DataCon
forall a. HasCallStack => [Char] -> a
error [Char]
"Cannot find BigNat constructor") Maybe DataCon
bnDcM

    let arr = ByteArray -> Literal
ByteArrayLiteral (ByteArray# -> ByteArray
ByteArray ByteArray#
ba)
    val <- VData bnDc [Left (VLiteral arr)] <$> getLocalEnv

    pure (Match alt [] [(i, val)])

-- | Match a saturated data constructor application against a case alternative.
--
-- A 'DataPat' for the same constructor matches. It binds the pattern's term
-- variables to the term arguments, and its type variables to the
-- existentially quantified type arguments. A 'DefaultPat' always matches
-- without binding anything. Every other pattern yields 'NoMatch'.
--
matchData :: DataCon -> Args Value -> (Pat, Value) -> Eval PatResult
matchData dc args alt@(pat, _) =
  case pat of
    DataPat c tvs ids
      |  dc == c
      -> do let (tms, tys) = partitionEithers args
                -- A DataPat only binds the existential type variables of the
                -- constructor. The type arguments of the application are
                -- assumed to be ordered like the constructor's type: universal
                -- tyvars ('dcUnivTyVars') first, then existential ones. The
                -- universal arguments are skipped so that each pattern tyvar
                -- is paired with its own type, not with a universal one.
                extTys = drop (length (dcUnivTyVars dc)) tys
            pure (Match alt (zip tvs extTys) (zip ids tms))

    DefaultPat -> pure (Match alt [] [])
    _ -> pure NoMatch

-- | Match a case alternative against a clash-prelude primitive that builds a
-- sized number (Bit, BitVector, Index, Signed or Unsigned) from a literal.
--
-- Only 'LitPat' alternatives can match. The primitive's value argument is
-- forced (and, for Bit and BitVector, its mask argument too). The alternative
-- matches when the value equals the pattern literal and any mask is zero.
-- All other patterns, including 'DefaultPat', yield 'NoMatch'.
--
-- NOTE(review): the mask/value bindings are refutable patterns in the 'Eval'
-- monad. An argument that does not force to the expected kind of literal
-- therefore goes through the monad's 'fail' instead of producing 'NoMatch'.
--
-- TODO Should this also consider DataPat and data constructors? The old
-- evaluator did not, and matchData would not cover that case here either.
--
matchClashPrim :: PrimInfo -> Args Value -> (Pat, Value) -> Eval PatResult
matchClashPrim pr args alt@(pat, _) =
  case pat of
    LitPat lit
      -- Bit literals
      | name == "Clash.Sized.BitVector.fromInteger##"
      , [Left mask, Left val] <- args
      -> do VLiteral (WordLiteral m) <- forceEval mask
            VLiteral l <- forceEval val
            pure (matchWhen (m == 0 && l == lit))

      -- BitVector literals
      | name == "Clash.Sized.BitVector.fromInteger#"
      , [Right _n, Left _knN, Left mask, Left val] <- args
      -> do VLiteral (NaturalLiteral m) <- forceEval mask
            VLiteral l <- forceEval val
            pure (matchWhen (m == 0 && l == lit))

      -- Sized integer / natural literals
      | name `elem` clashSizedNumbers
      , [Right _n, Left _knN, Left val] <- args
      -> do VLiteral l <- forceEval val
            pure (matchWhen (l == lit))

    -- The primitive is not a literal from clash-prelude
    _ -> pure NoMatch
 where
  name = primName pr

  -- Literal patterns never bind type or term variables, so a successful
  -- comparison is just a match on this alternative.
  matchWhen isMatch
    | isMatch   = Match alt [] []
    | otherwise = NoMatch

  clashSizedNumbers =
    [ "Clash.Sized.Internal.Index.fromInteger#"
    , "Clash.Sized.Internal.Signed.fromInteger#"
    , "Clash.Sized.Internal.Unsigned.fromInteger#"
    ]

-- | Pick the alternative a case expression should take.
--
-- Alternatives are checked in order with the given predicate. The first
-- alternative that matches with a pattern other than 'DefaultPat' is returned
-- immediately, and the alternatives after it are not checked. A matching
-- 'DefaultPat' is only kept as a fallback. It is returned when no more
-- specific alternative matches. If nothing matches, the result is 'NoMatch'.
--
findBestAlt
  :: ((Pat, Value) -> Eval PatResult)
  -> [(Pat, Value)]
  -> Eval PatResult
findBestAlt checkAlt = search NoMatch
 where
  search !fallback [] = pure fallback
  search !fallback (alt:rest) =
    checkAlt alt >>= \case
      NoMatch -> search fallback rest
      found@(Match (pat, _) _ _)
        -- A default only becomes the fallback; keep looking for a more
        -- specific match.
        | pat == DefaultPat -> search found rest
        | otherwise -> pure found

-- | Evaluate a cast. The inner term is evaluated with 'eval', then the source
-- and target types are evaluated with 'evalType', in that order. The results
-- are combined into a 'VCast'.
evalCast :: Term -> Type -> Type -> Eval Value
evalCast x a b = do
  inner <- eval x
  fromTy <- evalType a
  toTy <- evalType b
  pure (VCast inner fromTy toTy)

-- | Evaluate the term under a tick and re-attach the unchanged tick to the
-- resulting value.
evalTick :: TickInfo -> Term -> Eval Value
evalTick tick x = do
  inner <- eval x
  pure (VTick inner tick)

-- | Apply a value to a single argument. Term arguments ('Left') are handled
-- by 'apply' and type arguments ('Right') by 'applyTy'.
applyArg :: Value -> Arg Value -> Eval Value
applyArg val = \case
  Left arg -> apply val arg
  Right ty -> applyTy val ty

-- | Apply a value to a term argument.
--
-- The function value is forced first. An argument is "work-free" when
-- 'workFreeValue' says so. A work-free argument is used directly. Any other
-- argument is let-bound (presumably so the work it does is shared):
--
--   * letrec on the LHS: the application is pushed inside the existing let.
--     A non-work-free argument is bound to a fresh @workArg@ variable inside
--     that let, instead of creating a separate let outside it.
--
--   * neutral LHS: a neutral application is built. A non-work-free argument
--     is bound to a fresh @workArg@ variable around it.
--
--   * lambda on the LHS: the body is evaluated in the lambda's environment,
--     with 'withId' binding the lambda's binder to the argument. A
--     non-work-free argument is also let-bound to that binder around the
--     result.
--
-- Applying any other kind of value is an error.
--
apply :: Value -> Value -> Eval Value
apply val arg = do
  tcm <- getTyConMap
  forced <- forceEval val
  canApply <- workFreeValue arg

  let -- A fresh variable with the (evaluated) type of the argument, used to
      -- let-bind an argument that is not work-free.
      freshWorkVar =
        getUniqueId "workArg" =<< evalType (valueType tcm arg)

  case stripValue forced of
    VNeutral (NeLet bs x)
      | canApply -> VNeutral . NeLet bs <$> apply x arg
      | otherwise -> do
          var <- freshWorkVar
          inner <- apply x (VNeutral (NeVar var))
          pure (VNeutral (NeLet bs (VNeutral (NeLet (NonRec var arg) inner))))

    VNeutral neu
      | canApply -> pure (VNeutral (NeApp neu arg))
      | otherwise -> do
          var <- freshWorkVar
          let app = VNeutral (NeApp neu (VNeutral (NeVar var)))
          pure (VNeutral (NeLet (NonRec var arg) app))

    VLam i x env ->
      setLocalEnv env $ do
        inner <- withId i arg (eval x)
        pure $
          if canApply
            then inner
            else VNeutral (NeLet (NonRec i arg) inner)

    f ->
      error ("apply: Cannot apply " <> show arg <> " to " <> show f)
 where
  -- TODO Write an instance for InferType Value and use that instead
  valueType tcm = inferCoreTypeOf tcm . asTerm

-- | Apply a value to a type argument.
--
-- The function value is forced and the type argument is evaluated with
-- 'evalType'. A neutral function becomes a neutral type application. A type
-- lambda has its body evaluated in the environment stored in the 'VTyLam',
-- with 'withTyVar' binding its type variable to the evaluated type.
-- Applying any other kind of value is an error.
--
applyTy :: Value -> Type -> Eval Value
applyTy val ty = do
  fun <- forceEval val
  argTy <- evalType ty

  case stripValue fun of
    VNeutral neu ->
      pure (VNeutral (NeTyApp neu argTy))

    VTyLam tv body env ->
      setLocalEnv env (withTyVar tv argTy (eval body))

    other ->
      error ("applyTy: Cannot apply " <> show argTy <> " to " <> show other)