diff options
Diffstat (limited to 'polly/lib/Support/ScopHelper.cpp')
-rw-r--r-- | polly/lib/Support/ScopHelper.cpp | 109 |
1 files changed, 109 insertions, 0 deletions
diff --git a/polly/lib/Support/ScopHelper.cpp b/polly/lib/Support/ScopHelper.cpp index 0827dc0b08d..602009cd0df 100644 --- a/polly/lib/Support/ScopHelper.cpp +++ b/polly/lib/Support/ScopHelper.cpp @@ -17,12 +17,14 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/RegionInfo.h" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/IR/CFG.h" #include "llvm/Support/Debug.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; +using namespace polly; #define DEBUG_TYPE "polly-scop-helper" @@ -252,3 +254,110 @@ void polly::splitEntryBlockForAlloca(BasicBlock *EntryBlock, Pass *P) { // splitBlock updates DT, LI and RI. splitBlock(EntryBlock, I, DT, LI, RI); } + +/// The SCEVExpander will __not__ generate any code for an existing SDiv/SRem +/// instruction but just use it, if it is referenced as a SCEVUnknown. We want +/// however to generate new code if the instruction is in the analyzed region +/// and we generate code outside/in front of that region. Hence, we generate the +/// code for the SDiv/SRem operands in front of the analyzed region and then +/// create a new SDiv/SRem operation there too. +struct ScopExpander : SCEVVisitor<ScopExpander, const SCEV *> { + friend struct SCEVVisitor<ScopExpander, const SCEV *>; + + explicit ScopExpander(const Region &R, ScalarEvolution &SE, + const DataLayout &DL, const char *Name) + : Expander(SCEVExpander(SE, DL, Name)), SE(SE), Name(Name), R(R) {} + + Value *expandCodeFor(const SCEV *E, Type *Ty, Instruction *I) { + // If we generate code in the region we will immediately fall back to the + // SCEVExpander, otherwise we will stop at all unknowns in the SCEV and if + // needed replace them by copies computed in the entering block. + if (!R.contains(I)) + E = visit(E); + return Expander.expandCodeFor(E, Ty, I); + } + +private: + SCEVExpander Expander; + ScalarEvolution &SE; + const char *Name; + const Region &R; + + const SCEV *visitUnknown(const SCEVUnknown *E) { + Instruction *Inst = dyn_cast<Instruction>(E->getValue()); + if (!Inst || (Inst->getOpcode() != Instruction::SRem && + Inst->getOpcode() != Instruction::SDiv)) + return E; + + if (!R.contains(Inst)) + return E; + + Instruction *StartIP = R.getEnteringBlock()->getTerminator(); + + const SCEV *LHSScev = visit(SE.getSCEV(Inst->getOperand(0))); + const SCEV *RHSScev = visit(SE.getSCEV(Inst->getOperand(1))); + + Value *LHS = Expander.expandCodeFor(LHSScev, E->getType(), StartIP); + Value *RHS = Expander.expandCodeFor(RHSScev, E->getType(), StartIP); + + Inst = BinaryOperator::Create((Instruction::BinaryOps)Inst->getOpcode(), + LHS, RHS, Inst->getName() + Name, StartIP); + return SE.getSCEV(Inst); + } + + /// The following functions will just traverse the SCEV and rebuild it with + /// the new operands returned by the traversal. + /// + ///{ + const SCEV *visitConstant(const SCEVConstant *E) { return E; } + const SCEV *visitTruncateExpr(const SCEVTruncateExpr *E) { + return SE.getTruncateExpr(visit(E->getOperand()), E->getType()); + } + const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *E) { + return SE.getZeroExtendExpr(visit(E->getOperand()), E->getType()); + } + const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *E) { + return SE.getSignExtendExpr(visit(E->getOperand()), E->getType()); + } + const SCEV *visitUDivExpr(const SCEVUDivExpr *E) { + return SE.getUDivExpr(visit(E->getLHS()), visit(E->getRHS())); + } + const SCEV *visitAddExpr(const SCEVAddExpr *E) { + SmallVector<const SCEV *, 4> NewOps; + for (const SCEV *Op : E->operands()) + NewOps.push_back(visit(Op)); + return SE.getAddExpr(NewOps); + } + const SCEV *visitMulExpr(const SCEVMulExpr *E) { + SmallVector<const SCEV *, 4> NewOps; + for (const SCEV *Op : E->operands()) + NewOps.push_back(visit(Op)); + return SE.getMulExpr(NewOps); + } + const SCEV *visitUMaxExpr(const SCEVUMaxExpr *E) { + SmallVector<const SCEV *, 4> NewOps; + for (const SCEV *Op : E->operands()) + NewOps.push_back(visit(Op)); + return SE.getUMaxExpr(NewOps); + } + const SCEV *visitSMaxExpr(const SCEVSMaxExpr *E) { + SmallVector<const SCEV *, 4> NewOps; + for (const SCEV *Op : E->operands()) + NewOps.push_back(visit(Op)); + return SE.getSMaxExpr(NewOps); + } + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *E) { + SmallVector<const SCEV *, 4> NewOps; + for (const SCEV *Op : E->operands()) + NewOps.push_back(visit(Op)); + return SE.getAddRecExpr(NewOps, E->getLoop(), E->getNoWrapFlags()); + } + ///} +}; + +Value *polly::expandCodeFor(Scop &S, ScalarEvolution &SE, const DataLayout &DL, + const char *Name, const SCEV *E, Type *Ty, + Instruction *IP) { + ScopExpander Expander(S.getRegion(), SE, DL, Name); + return Expander.expandCodeFor(E, Ty, IP); +} |