diff --git a/common/EvalConstants.hs b/common/EvalConstants.hs index 8b43e05..c51b16f 100644 --- a/common/EvalConstants.hs +++ b/common/EvalConstants.hs @@ -38,13 +38,15 @@ import Utils -- | Simplify an expression by constant folding, and also return whether it's a -- constant after that. -constantFold :: CSMR m => A.Expression -> m (A.Expression, Bool, ErrorReport) +constantFold :: (CSMR m, Die m) => A.Expression -> m (A.Expression, Bool, ErrorReport) constantFold e = do ps <- getCompState - let (e', msg) = case simplifyExpression ps e of - Left err -> (e, err) - Right val -> (val, (Nothing, "already folded")) - return (e', isConstant e', msg) + t <- typeOfExpression e + case runEvaluator ps (evalExpression e) of + Left err -> return (e, False, err) + Right val -> + do e' <- renderValue (findMeta e) t val + return (e', isConstant e', (Nothing, "already folded")) -- | Try to fold and evaluate an integer expression. -- If it's not a constant, return 'Nothing'. @@ -73,14 +75,6 @@ isConstantName n Just _ -> True Nothing -> False --- | Attempt to simplify an expression as far as possible by precomputing --- constant bits. -simplifyExpression :: CompState -> A.Expression -> Either ErrorReport A.Expression -simplifyExpression ps e - = case runEvaluator ps (evalExpression e) of - Left err -> Left err - Right val -> Right $ snd $ renderValue (findMeta e) val - --{{{ expression evaluator evalLiteral :: A.Expression -> EvalM OccValue evalLiteral (A.Literal m _ (A.ArrayLiteral _ [])) @@ -274,51 +268,62 @@ evalDyadic op _ _ = throwError (Nothing, "bad dyadic op: " ++ show op) --}}} --{{{ rendering values --- | Convert a value back into a literal. -renderValue :: Meta -> OccValue -> (A.Type, A.Expression) -renderValue m (OccBool True) = (A.Bool, A.True m) -renderValue m (OccBool False) = (A.Bool, A.False m) -renderValue m v = (t, A.Literal m t lr) - where (t, lr) = renderLiteral m v +-- | Convert an 'OccValue' back into a (literal) 'Expression'. +renderValue :: (CSMR m, Die m) => Meta -> A.Type -> OccValue -> m A.Expression +renderValue m _ (OccBool True) = return $ A.True m +renderValue m _ (OccBool False) = return $ A.False m +renderValue m t v = renderLiteral m t v >>* A.Literal m t -renderLiteral :: Meta -> OccValue -> (A.Type, A.LiteralRepr) -renderLiteral m (OccByte c) = (A.Byte, A.ByteLiteral m $ renderChar (chr $ fromIntegral c)) -renderLiteral m (OccUInt16 i) = (A.UInt16, A.IntLiteral m $ show i) -renderLiteral m (OccUInt32 i) = (A.UInt32, A.IntLiteral m $ show i) -renderLiteral m (OccUInt64 i) = (A.UInt64, A.IntLiteral m $ show i) -renderLiteral m (OccInt8 i) = (A.Int8, A.IntLiteral m $ show i) -renderLiteral m (OccInt i) = (A.Int, A.IntLiteral m $ show i) -renderLiteral m (OccInt16 i) = (A.Int16, A.IntLiteral m $ show i) -renderLiteral m (OccInt32 i) = (A.Int32, A.IntLiteral m $ show i) -renderLiteral m (OccInt64 i) = (A.Int64, A.IntLiteral m $ show i) -renderLiteral m (OccArray vs) - = (t, A.ArrayLiteral m aes) +-- | Convert an 'OccValue' back into a 'LiteralRepr'. +renderLiteral :: (CSMR m, Die m) => Meta -> A.Type -> OccValue -> m A.LiteralRepr +renderLiteral m t v + = case v of + OccByte c -> + return $ A.ByteLiteral m $ renderChar (chr $ fromIntegral c) + OccUInt16 i -> renderInt i + OccUInt32 i -> renderInt i + OccUInt64 i -> renderInt i + OccInt8 i -> renderInt i + OccInt i -> renderInt i + OccInt16 i -> renderInt i + OccInt32 i -> renderInt i + OccInt64 i -> renderInt i + OccArray vs -> renderArray vs + OccRecord _ vs -> renderRecord vs where - t = addDimensions [makeDimension m $ length vs] (head ts) - (ts, aes) = unzip $ map (renderLiteralArray m) vs -renderLiteral m (OccRecord n vs) - = (A.Record n, A.RecordLiteral m (map (snd . renderValue m) vs)) + renderChar :: Char -> String + renderChar '\'' = "*'" + renderChar '\"' = "*\"" + renderChar '*' = "**" + renderChar '\r' = "*c" + renderChar '\n' = "*n" + renderChar '\t' = "*t" + renderChar c + | (o < 32 || o > 127) = printf "*#%02x" o + | otherwise = [c] + where o = ord c -renderChar :: Char -> String -renderChar '\'' = "*'" -renderChar '\"' = "*\"" -renderChar '*' = "**" -renderChar '\r' = "*c" -renderChar '\n' = "*n" -renderChar '\t' = "*t" -renderChar c - | (o < 32 || o > 127) = printf "*#%02x" o - | otherwise = [c] - where o = ord c + renderInt :: (Show s, CSMR m, Die m) => s -> m A.LiteralRepr + renderInt i = return $ A.IntLiteral m $ show i -renderLiteralArray :: Meta -> OccValue -> (A.Type, A.ArrayElem) -renderLiteralArray m (OccArray vs) - = (t, A.ArrayElemArray aes) - where - t = addDimensions [makeDimension m $ length vs] (head ts) - (ts, aes) = unzip $ map (renderLiteralArray m) vs -renderLiteralArray m v - = (t, A.ArrayElemExpr e) - where - (t, e) = renderValue m v + renderArray :: (CSMR m, Die m) => [OccValue] -> m A.LiteralRepr + renderArray vs + = do subT <- trivialSubscriptType m t + aes <- mapM (renderArrayElem subT) vs + return $ A.ArrayLiteral m aes + + renderArrayElem :: (CSMR m, Die m) => A.Type -> OccValue -> m A.ArrayElem + renderArrayElem t (OccArray vs) + = do subT <- trivialSubscriptType m t + aes <- mapM (renderArrayElem subT) vs + return $ A.ArrayElemArray aes + renderArrayElem t v = renderValue m t v >>* A.ArrayElemExpr + + renderRecord :: (CSMR m, Die m) => [OccValue] -> m A.LiteralRepr + renderRecord vs + = do ts <- case t of + A.Infer -> return [A.Infer | _ <- vs] + _ -> recordFields m t >>* map snd + es <- sequence [renderValue m fieldT v | (fieldT, v) <- zip ts vs] + return $ A.RecordLiteral m es --}}}