diff --git a/checks/ArrayUsageCheck.hs b/checks/ArrayUsageCheck.hs index ba77ff8..a1176bb 100644 --- a/checks/ArrayUsageCheck.hs +++ b/checks/ArrayUsageCheck.hs @@ -56,7 +56,7 @@ import Utils -- Each list is a possible set of background knowledge mapping vars to a list -- of constraints. So it is a disjunction of map from variables to conjunctions type BK = [Map.Map Var [BackgroundKnowledge]] -type BK' = [Map.Map Var (EqualityProblem, InequalityProblem)] +type BK' = [Map.Map Var (Either String (EqualityProblem, InequalityProblem))] -- | Given a list of replicators, and a set of background knowledge for each -- access inside the replicator, checks if there are any solutions for a @@ -487,12 +487,13 @@ makeEquations accesses bound do ((accesses', allReps),repVars) <- flip runStateT [] $ parItemToArrayAccessM mkEq accesses high <- makeSingleEq id bound "upper bound" return (accesses', high, nub repVars, allReps) - return $ squareAndPair (lookupBK allReps) (\(x,y,_) -> (x,y)) repVarIndexes s v (amap (const 0) h, addConstant (-1) h) + squareAndPair (lookupBK allReps) (\(x,y,_) -> (x,y)) repVarIndexes s v (amap (const 0) h, addConstant (-1) h) where - lookupBK :: [A.Name] -> (A.Expression, [ModuloCase], BK') -> [(EqualityProblem, InequalityProblem)] - lookupBK reps (e,_,bk) = map (foldl accumProblem ([],[]) . map snd . - filter (\(v,eq) -> v `elem` vs || v `elem` reps') . Map.toList) bk + lookupBK :: [A.Name] -> (A.Expression, [ModuloCase], BK') -> Either String + [(EqualityProblem, InequalityProblem)] + lookupBK reps (e,_,bk) = mapM (foldl (liftM2 accumProblem) (return ([],[])) . map snd . + filter (\(v,_) -> v `elem` vs || v `elem` reps') . Map.toList) bk where reps' :: [Var] reps' = map (Var . A.Variable emptyMeta) reps @@ -683,22 +684,26 @@ flatten e = return [Scale 1 (canonicalise e,0)] -- will produce both "i = i' + 1" and "i + 1 = i'" so there is no need -- to vary the inequality itself. squareAndPair :: - (label -> [(EqualityProblem, InequalityProblem)]) -> + (label -> Either String [(EqualityProblem, InequalityProblem)]) -> (label -> labelStripped) -> [(CoeffIndex, CoeffIndex)] -> VarMap -> [ArrayAccess label] -> (EqualityConstraintEquation, EqualityConstraintEquation) -> - [((labelStripped, labelStripped), VarMap, (EqualityProblem, InequalityProblem))] + Either String [((labelStripped, labelStripped), VarMap, (EqualityProblem, InequalityProblem))] squareAndPair lookupBK strip repVars s v lh - = [(transformPair strip strip labels, s,squareEquations (nub (bkEqA ++ bkEqB) ++ - eq, nub (bkIneqA ++ bkIneqB) ++ ineq ++ concat (applyAll (eq,ineq) (map addExtra repVars)))) + = concatMapM id + [let f ((bkEqA, bkIneqA), (bkEqB, bkIneqB)) + = (transformPair strip strip labels, + s, + squareEquations (nub (bkEqA ++ bkEqB) ++ eq, + nub (bkIneqA ++ bkIneqB) ++ ineq ++ concat (applyAll (eq,ineq) (map addExtra repVars)))) + bk = case liftM2 (curry product2) (lookupBK (fst labels)) (lookupBK (snd labels)) of + Right [] -> Right [(([],[]),([],[]))] -- No BK + xs -> xs + in bk >>* map f | (labels, eq,ineq) <- pairEqsAndBounds v lh ,and (map (primeImpliesPlain (eq,ineq)) repVars) - ,((bkEqA, bkIneqA), (bkEqB, bkIneqB)) <- - case product2 (lookupBK (fst labels), lookupBK (snd labels)) of - [] -> [(([],[]),([],[]))] -- No BK - xs -> xs ] where itemPresent :: CoeffIndex -> [Array CoeffIndex Integer] -> Bool @@ -802,13 +807,20 @@ getIneqs (low, high) = concatMap getLH getLH :: EqualityConstraintEquation -> [InequalityConstraintEquation] getLH eq = [eq `addEq` (amap negate low),high `addEq` amap negate eq] +justState :: Error e => StateT s (Either e) a -> StateT s (Either e) (Either e a) +justState m = do st <- get + let (x, st') = case runStateT m st of + Left err -> (Left err, st) + Right (x, st') -> (Right x, st') + put st' + return x -- | Given an expression, forms equations (and accompanying additional equation-sets) and returns it makeEquation :: label -> (BK, [FlattenedExp] -> [FlattenedExp]) -> ArrayAccessType -> [FlattenedExp] -> StateT VarMap (Either String) (ArrayAccess (label,[ModuloCase], BK')) makeEquation l (bk, bkF) t summedItems = do eqs <- process summedItems - bk' <- mapM (mapMapM $ transformBKList bkF) bk + bk' <- mapM (mapMapM (justState . transformBKList bkF)) bk let eqs' = map (transformQuad id mapToArray (map mapToArray) (map mapToArray)) eqs :: [([ModuloCase], EqualityConstraintEquation, EqualityProblem, InequalityProblem)] return $ Group [((l,c,bk'),t,(e0,e1,e2)) | (c,e0,e1,e2) <- eqs'] where