Diffstat (limited to 'mlir/lib/Transforms')
 mlir/lib/Transforms/CSE.cpp                  |  1
 mlir/lib/Transforms/ComposeAffineMaps.cpp    | 15
 mlir/lib/Transforms/ConstantFold.cpp         | 11
 mlir/lib/Transforms/LoopFusion.cpp           | 90
 mlir/lib/Transforms/LoopUnroll.cpp           | 49
 mlir/lib/Transforms/LoopUnrollAndJam.cpp     | 20
 mlir/lib/Transforms/MemRefDataFlowOpt.cpp    | 17
 mlir/lib/Transforms/PipelineDataTransfer.cpp |  6
 mlir/lib/Transforms/Utils/LoopUtils.cpp      |  7
 9 files changed, 88 insertions(+), 128 deletions(-)
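
Note: every file in this diffstat moves off the CRTP InstWalker/InstVisitor mixins and onto walk methods that take a callback. Rather than inheriting a visitor base class and overriding visitInstruction, each pass now hands a lambda to Function::walk. A rough, self-contained sketch of the callback style, with toy types standing in for MLIR's (Node is not mlir::Instruction):

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<std::unique_ptr<Node>> children;

  // walk() hands this node, then everything nested below it, to the
  // callback; traversal state lives in the caller's lambda captures.
  void walk(const std::function<void(Node *)> &fn) {
    fn(this);
    for (auto &c : children)
      c->walk(fn);
  }
};

int main() {
  Node func{"func", {}};
  func.children.push_back(std::make_unique<Node>(Node{"affine.for", {}}));
  func.children.push_back(std::make_unique<Node>(Node{"load", {}}));

  // Gather matching nodes with a lambda, as the rewritten passes do.
  std::vector<Node *> loops;
  func.walk([&](Node *n) {
    if (n->name == "affine.for")
      loops.push_back(n);
  });
  std::cout << loops.size() << " loop(s) found\n"; // 1 loop(s) found
}

The lambda captures whatever pass state it needs, which is what lets the patches below delete their visitor subclasses outright.
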
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 63a676d7b52..de10fe8a461 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -24,7 +24,6 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
-#include "mlir/IR/InstVisitor.h"
 #include "mlir/Pass.h"
 #include "mlir/Support/Functional.h"
 #include "mlir/Transforms/Passes.h"
diff --git a/mlir/lib/Transforms/ComposeAffineMaps.cpp b/mlir/lib/Transforms/ComposeAffineMaps.cpp
index 289b00d3b51..796477c64f2 100644
--- a/mlir/lib/Transforms/ComposeAffineMaps.cpp
+++ b/mlir/lib/Transforms/ComposeAffineMaps.cpp
@@ -27,7 +27,6 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/InstVisitor.h"
 #include "mlir/Pass.h"
 #include "mlir/StandardOps/StandardOps.h"
 #include "mlir/Transforms/Passes.h"
@@ -46,10 +45,9 @@ namespace {
 // result of any AffineApplyOp). After this composition, AffineApplyOps with no
 // remaining uses are erased.
 // TODO(andydavis) Remove this when Chris adds instruction combiner pass.
-struct ComposeAffineMaps : public FunctionPass, InstWalker<ComposeAffineMaps> {
+struct ComposeAffineMaps : public FunctionPass {
   explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {}
   PassResult runOnFunction(Function *f) override;
-  void visitInstruction(Instruction *opInst);
 
   SmallVector<OpPointer<AffineApplyOp>, 8> affineApplyOps;
 
@@ -68,15 +66,11 @@ static bool affineApplyOp(const Instruction &inst) {
   return inst.isa<AffineApplyOp>();
 }
 
-void ComposeAffineMaps::visitInstruction(Instruction *opInst) {
-  if (auto afOp = opInst->dyn_cast<AffineApplyOp>())
-    affineApplyOps.push_back(afOp);
-}
-
 PassResult ComposeAffineMaps::runOnFunction(Function *f) {
   // If needed for future efficiency, reserve space based on a pre-walk.
   affineApplyOps.clear();
-  walk(f);
+  f->walk<AffineApplyOp>(
+      [&](OpPointer<AffineApplyOp> afOp) { affineApplyOps.push_back(afOp); });
   for (auto afOp : affineApplyOps) {
     SmallVector<Value *, 8> operands(afOp->getOperands());
     FuncBuilder b(afOp->getInstruction());
@@ -87,7 +81,8 @@ PassResult ComposeAffineMaps::runOnFunction(Function *f) {
 
   // Erase dead affine apply ops.
   affineApplyOps.clear();
-  walk(f);
+  f->walk<AffineApplyOp>(
+      [&](OpPointer<AffineApplyOp> afOp) { affineApplyOps.push_back(afOp); });
   for (auto it = affineApplyOps.rbegin(); it != affineApplyOps.rend(); ++it) {
     if ((*it)->use_empty()) {
       (*it)->erase();
diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp
index 54486cdb293..e41ac0ad329 100644
--- a/mlir/lib/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Transforms/ConstantFold.cpp
@@ -18,7 +18,6 @@
 #include "mlir/AffineOps/AffineOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
-#include "mlir/IR/InstVisitor.h"
 #include "mlir/Pass.h"
 #include "mlir/Transforms/Passes.h"
 #include "mlir/Transforms/Utils.h"
@@ -27,7 +26,7 @@ using namespace mlir;
 
 namespace {
 /// Simple constant folding pass.
-struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> {
+struct ConstantFold : public FunctionPass {
   ConstantFold() : FunctionPass(&ConstantFold::passID) {}
 
   // All constants in the function post folding.
@@ -35,9 +34,7 @@ struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> {
   // Operations that were folded and that need to be erased.
   std::vector<Instruction *> opInstsToErase;
 
-  bool foldOperation(Instruction *op,
-                     SmallVectorImpl<Value *> &existingConstants);
-  void visitInstruction(Instruction *op);
+  void foldInstruction(Instruction *op);
   PassResult runOnFunction(Function *f) override;
 
   static char passID;
@@ -49,7 +46,7 @@ char ConstantFold::passID = 0;
 /// Attempt to fold the specified operation, updating the IR to match. If
 /// constants are found, we keep track of them in the existingConstants list.
 ///
-void ConstantFold::visitInstruction(Instruction *op) {
+void ConstantFold::foldInstruction(Instruction *op) {
   // If this operation is an AffineForOp, then fold the bounds.
   if (auto forOp = op->dyn_cast<AffineForOp>()) {
     constantFoldBounds(forOp);
@@ -111,7 +108,7 @@ PassResult ConstantFold::runOnFunction(Function *f) {
   existingConstants.clear();
   opInstsToErase.clear();
 
-  walk(f);
+  f->walk([&](Instruction *inst) { foldInstruction(inst); });
 
   // At this point, these operations are dead, remove them.
   // TODO: This is assuming that all constant foldable operations have no
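
ComposeAffineMaps above (and MemRefDataFlowOpt further down) use the typed form f->walk<AffineApplyOp>(...), which fires the callback only on ops of the requested kind. One plausible way such a typed overload can be layered on a generic walk, sketched with toy classes (dynamic_cast stands in for MLIR's op casting; none of this is MLIR's real API):

#include <functional>
#include <iostream>
#include <memory>
#include <vector>

struct Inst {
  virtual ~Inst() = default;
  std::vector<std::unique_ptr<Inst>> children;

  // Generic walk: visit this instruction, then everything nested below it.
  void walk(const std::function<void(Inst *)> &fn) {
    fn(this);
    for (auto &c : children)
      c->walk(fn);
  }

  // Typed walk: only invoke the callback on instructions of type OpTy.
  template <typename OpTy>
  void walk(const std::function<void(OpTy *)> &fn) {
    walk([&](Inst *i) {
      if (auto *op = dynamic_cast<OpTy *>(i))
        fn(op);
    });
  }
};

struct AffineApply : Inst {};
struct Load : Inst {};

int main() {
  Inst root;
  root.children.push_back(std::make_unique<AffineApply>());
  root.children.push_back(std::make_unique<Load>());

  std::vector<AffineApply *> applies;
  root.walk<AffineApply>([&](AffineApply *op) { applies.push_back(op); });
  std::cout << applies.size() << " affine apply op(s)\n"; // 1
}

The typed form saves every caller the dyn_cast-and-early-return boilerplate that the deleted visitInstruction bodies repeated.
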
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 7a002168528..77e5a6aa04f 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -28,7 +28,6 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/InstVisitor.h"
 #include "mlir/Pass.h"
 #include "mlir/StandardOps/StandardOps.h"
 #include "mlir/Transforms/LoopUtils.h"
@@ -111,22 +110,23 @@ namespace {
 
 // LoopNestStateCollector walks loop nests and collects load and store
 // operations, and whether or not an IfInst was encountered in the loop nest.
-class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
-public:
+struct LoopNestStateCollector {
   SmallVector<OpPointer<AffineForOp>, 4> forOps;
   SmallVector<Instruction *, 4> loadOpInsts;
   SmallVector<Instruction *, 4> storeOpInsts;
   bool hasNonForRegion = false;
 
-  void visitInstruction(Instruction *opInst) {
-    if (opInst->isa<AffineForOp>())
-      forOps.push_back(opInst->cast<AffineForOp>());
-    else if (opInst->getNumBlockLists() != 0)
-      hasNonForRegion = true;
-    else if (opInst->isa<LoadOp>())
-      loadOpInsts.push_back(opInst);
-    else if (opInst->isa<StoreOp>())
-      storeOpInsts.push_back(opInst);
+  void collect(Instruction *instToWalk) {
+    instToWalk->walk([&](Instruction *opInst) {
+      if (opInst->isa<AffineForOp>())
+        forOps.push_back(opInst->cast<AffineForOp>());
+      else if (opInst->getNumBlockLists() != 0)
+        hasNonForRegion = true;
+      else if (opInst->isa<LoadOp>())
+        loadOpInsts.push_back(opInst);
+      else if (opInst->isa<StoreOp>())
+        storeOpInsts.push_back(opInst);
+    });
   }
 };
 
@@ -510,7 +510,7 @@ bool MemRefDependenceGraph::init(Function *f) {
       // Create graph node 'id' to represent top-level 'forOp' and record
       // all loads and store accesses it contains.
       LoopNestStateCollector collector;
-      collector.walk(&inst);
+      collector.collect(&inst);
       // Return false if a non 'for' region was found (not currently supported).
       if (collector.hasNonForRegion)
         return false;
@@ -606,41 +606,39 @@ struct LoopNestStats {
 
 // LoopNestStatsCollector walks a single loop nest and gathers per-loop
 // trip count and operation count statistics and records them in 'stats'.
-class LoopNestStatsCollector : public InstWalker<LoopNestStatsCollector> {
-public:
+struct LoopNestStatsCollector {
   LoopNestStats *stats;
   bool hasLoopWithNonConstTripCount = false;
 
   LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}
 
-  void visitInstruction(Instruction *opInst) {
-    auto forOp = opInst->dyn_cast<AffineForOp>();
-    if (!forOp)
-      return;
-
-    auto *forInst = forOp->getInstruction();
-    auto *parentInst = forOp->getInstruction()->getParentInst();
-    if (parentInst != nullptr) {
-      assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp");
-      // Add mapping to 'forOp' from its parent AffineForOp.
-      stats->loopMap[parentInst].push_back(forOp);
-    }
+  void collect(Instruction *inst) {
+    inst->walk<AffineForOp>([&](OpPointer<AffineForOp> forOp) {
+      auto *forInst = forOp->getInstruction();
+      auto *parentInst = forOp->getInstruction()->getParentInst();
+      if (parentInst != nullptr) {
+        assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp");
+        // Add mapping to 'forOp' from its parent AffineForOp.
+        stats->loopMap[parentInst].push_back(forOp);
+      }
 
-    // Record the number of op instructions in the body of 'forOp'.
-    unsigned count = 0;
-    stats->opCountMap[forInst] = 0;
-    for (auto &inst : *forOp->getBody()) {
-      if (!(inst.isa<AffineForOp>() || inst.isa<AffineIfOp>()))
-        ++count;
-    }
-    stats->opCountMap[forInst] = count;
-    // Record trip count for 'forOp'. Set flag if trip count is not constant.
-    Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
-    if (!maybeConstTripCount.hasValue()) {
-      hasLoopWithNonConstTripCount = true;
-      return;
-    }
-    stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
+      // Record the number of op instructions in the body of 'forOp'.
+      unsigned count = 0;
+      stats->opCountMap[forInst] = 0;
+      for (auto &inst : *forOp->getBody()) {
+        if (!(inst.isa<AffineForOp>() || inst.isa<AffineIfOp>()))
+          ++count;
+      }
+      stats->opCountMap[forInst] = count;
+      // Record trip count for 'forOp'. Set flag if trip count is not
+      // constant.
+      Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
+      if (!maybeConstTripCount.hasValue()) {
+        hasLoopWithNonConstTripCount = true;
+        return;
+      }
+      stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
+    });
   }
 };
 
@@ -1078,7 +1076,7 @@ static bool isFusionProfitable(Instruction *srcOpInst,
   // Walk src loop nest and collect stats.
   LoopNestStats srcLoopNestStats;
   LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
-  srcStatsCollector.walk(srcLoopIVs[0]->getInstruction());
+  srcStatsCollector.collect(srcLoopIVs[0]->getInstruction());
   // Currently only constant trip count loop nests are supported.
   if (srcStatsCollector.hasLoopWithNonConstTripCount)
     return false;
@@ -1089,7 +1087,7 @@ static bool isFusionProfitable(Instruction *srcOpInst,
 
   LoopNestStats dstLoopNestStats;
   LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
-  dstStatsCollector.walk(dstLoopIVs[0]->getInstruction());
+  dstStatsCollector.collect(dstLoopIVs[0]->getInstruction());
   // Currently only constant trip count loop nests are supported.
   if (dstStatsCollector.hasLoopWithNonConstTripCount)
     return false;
@@ -1474,7 +1472,7 @@ public:
 
       // Collect slice loop stats.
       LoopNestStateCollector sliceCollector;
-      sliceCollector.walk(sliceLoopNest->getInstruction());
+      sliceCollector.collect(sliceLoopNest->getInstruction());
       // Promote single iteration slice loops to single IV value.
       for (auto forOp : sliceCollector.forOps) {
         promoteIfSingleIteration(forOp);
@@ -1498,7 +1496,7 @@ public:
 
       // Collect dst loop stats after memref privatizaton transformation.
       LoopNestStateCollector dstLoopCollector;
-      dstLoopCollector.walk(dstAffineForOp->getInstruction());
+      dstLoopCollector.collect(dstAffineForOp->getInstruction());
 
       // Add new load ops to current Node load op list 'loads' to
       // continue fusing based on new operands.
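
The LoopFusion rewrite keeps its collector structs but drives them with walks: the state stays in plain data members, and collect() fills them through a capturing lambda. A minimal toy sketch of that collector pattern (illustrative types only, not MLIR's):

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Inst {
  std::string kind;
  std::vector<std::unique_ptr<Inst>> body;

  void walk(const std::function<void(Inst *)> &fn) {
    fn(this);
    for (auto &c : body)
      c->walk(fn);
  }
};

struct LoopNestCollector {
  // Collector state lives in ordinary members, not a visitor base class.
  std::vector<Inst *> forInsts, loadInsts, storeInsts;

  void collect(Inst *root) {
    // The [&] capture gives the callback access to this collector's members.
    root->walk([&](Inst *i) {
      if (i->kind == "for")
        forInsts.push_back(i);
      else if (i->kind == "load")
        loadInsts.push_back(i);
      else if (i->kind == "store")
        storeInsts.push_back(i);
    });
  }
};

int main() {
  Inst loop{"for", {}};
  loop.body.push_back(std::make_unique<Inst>(Inst{"load", {}}));
  loop.body.push_back(std::make_unique<Inst>(Inst{"store", {}}));

  LoopNestCollector c;
  c.collect(&loop);
  std::cout << c.loadInsts.size() << " load, "
            << c.storeInsts.size() << " store\n"; // 1 load, 1 store
}
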
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
index b1e15ccb07b..3a7cfb85e08 100644
--- a/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -27,7 +27,6 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/InstVisitor.h"
 #include "mlir/Pass.h"
 #include "mlir/Transforms/LoopUtils.h"
 #include "llvm/ADT/DenseMap.h"
@@ -95,15 +94,16 @@ char LoopUnroll::passID = 0;
 
 PassResult LoopUnroll::runOnFunction(Function *f) {
   // Gathers all innermost loops through a post order pruned walk.
-  class InnermostLoopGatherer : public InstWalker<InnermostLoopGatherer, bool> {
-  public:
+  struct InnermostLoopGatherer {
     // Store innermost loops as we walk.
     std::vector<OpPointer<AffineForOp>> loops;
 
-    // This method specialized to encode custom return logic.
-    using InstListType = llvm::iplist<Instruction>;
-    bool walkPostOrder(InstListType::iterator Start,
-                       InstListType::iterator End) {
+    void walkPostOrder(Function *f) {
+      for (auto &b : *f)
+        walkPostOrder(b.begin(), b.end());
+    }
+
+    bool walkPostOrder(Block::iterator Start, Block::iterator End) {
       bool hasInnerLoops = false;
       // We need to walk all elements since all innermost loops need to be
       // gathered as opposed to determining whether this list has any inner
@@ -112,7 +112,6 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
         hasInnerLoops |= walkPostOrder(&(*Start++));
       return hasInnerLoops;
     }
-
     bool walkPostOrder(Instruction *opInst) {
       bool hasInnerLoops = false;
       for (auto &blockList : opInst->getBlockLists())
@@ -125,39 +124,21 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
       }
       return hasInnerLoops;
     }
-
-    // FIXME: can't use base class method for this because that in turn would
-    // need to use the derived class method above. CRTP doesn't allow it, and
-    // the compiler error resulting from it is also misleading.
-    using InstWalker<InnermostLoopGatherer, bool>::walkPostOrder;
   };
 
-  // Gathers all loops with trip count <= minTripCount.
-  class ShortLoopGatherer : public InstWalker<ShortLoopGatherer> {
-  public:
+  if (clUnrollFull.getNumOccurrences() > 0 &&
+      clUnrollFullThreshold.getNumOccurrences() > 0) {
     // Store short loops as we walk.
     std::vector<OpPointer<AffineForOp>> loops;
-    const unsigned minTripCount;
-    ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
 
-    void visitInstruction(Instruction *opInst) {
-      auto forOp = opInst->dyn_cast<AffineForOp>();
-      if (!forOp)
-        return;
+    // Gathers all loops with trip count <= minTripCount. Do a post order walk
+    // so that loops are gathered from innermost to outermost (or else unrolling
+    // an outer one may delete gathered inner ones).
+    f->walkPostOrder<AffineForOp>([&](OpPointer<AffineForOp> forOp) {
      Optional<uint64_t> tripCount = getConstantTripCount(forOp);
-      if (tripCount.hasValue() && tripCount.getValue() <= minTripCount)
+      if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold)
         loops.push_back(forOp);
-    }
-  };
-
-  if (clUnrollFull.getNumOccurrences() > 0 &&
-      clUnrollFullThreshold.getNumOccurrences() > 0) {
-    ShortLoopGatherer slg(clUnrollFullThreshold);
-    // Do a post order walk so that loops are gathered from innermost to
-    // outermost (or else unrolling an outer one may delete gathered inner
-    // ones).
-    slg.walkPostOrder(f);
-    auto &loops = slg.loops;
+    });
     for (auto forOp : loops)
       loopUnrollFull(forOp);
     return success();
diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
index 74c54fde047..b2aed7d9d7f 100644
--- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp
+++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
@@ -50,7 +50,6 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/InstVisitor.h"
 #include "mlir/Pass.h"
 #include "mlir/Transforms/LoopUtils.h"
 #include "llvm/ADT/DenseMap.h"
@@ -136,24 +135,25 @@ bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp,
   // Gathers all maximal sub-blocks of instructions that do not themselves
   // include a for inst (a instruction could have a descendant for inst though
   // in its tree).
-  class JamBlockGatherer : public InstWalker<JamBlockGatherer> {
-  public:
-    using InstListType = llvm::iplist<Instruction>;
-    using InstWalker<JamBlockGatherer>::walk;
-
+  struct JamBlockGatherer {
     // Store iterators to the first and last inst of each sub-block found.
     std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
 
     // This is a linear time walk.
-    void walk(InstListType::iterator Start, InstListType::iterator End) {
-      for (auto it = Start; it != End;) {
+    void walk(Instruction *inst) {
+      for (auto &blockList : inst->getBlockLists())
+        for (auto &block : blockList)
+          walk(block);
+    }
+    void walk(Block &block) {
+      for (auto it = block.begin(), e = block.end(); it != e;) {
         auto subBlockStart = it;
-        while (it != End && !it->isa<AffineForOp>())
+        while (it != e && !it->isa<AffineForOp>())
           ++it;
         if (it != subBlockStart)
           subBlocks.push_back({subBlockStart, std::prev(it)});
         // Process all for insts that appear next.
-        while (it != End && it->isa<AffineForOp>())
+        while (it != e && it->isa<AffineForOp>())
           walk(&*it++);
       }
     }
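
The JamBlockGatherer above reads well in isolation: it slices each block into maximal runs of instructions that are not for insts, then recurses into the for insts between runs. A self-contained sketch of the same slicing over a plain vector of strings (illustrative only; where the comment indicates, a real walker recurses into loop bodies):

#include <iostream>
#include <iterator>
#include <string>
#include <utility>
#include <vector>

using It = std::vector<std::string>::iterator;

// Collect [first, last] iterator pairs bounding each maximal run of
// entries that are not "for", mirroring JamBlockGatherer::walk(Block &).
std::vector<std::pair<It, It>> gatherSubBlocks(std::vector<std::string> &block) {
  std::vector<std::pair<It, It>> subBlocks;
  for (auto it = block.begin(), e = block.end(); it != e;) {
    auto start = it;
    while (it != e && *it != "for")
      ++it;
    if (it != start)
      subBlocks.push_back({start, std::prev(it)});
    // Skip over the loops; a real walker would recurse into their bodies.
    while (it != e && *it == "for")
      ++it;
  }
  return subBlocks;
}

int main() {
  std::vector<std::string> block = {"load", "add", "for", "store"};
  std::cout << gatherSubBlocks(block).size() << " sub-block(s)\n"; // 2
}
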
diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
index 2d06a327315..9c9db30d163 100644
--- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -25,7 +25,6 @@
 #include "mlir/Analysis/AffineAnalysis.h"
 #include "mlir/Analysis/Dominance.h"
 #include "mlir/Analysis/Utils.h"
-#include "mlir/IR/InstVisitor.h"
 #include "mlir/Pass.h"
 #include "mlir/StandardOps/StandardOps.h"
 #include "mlir/Transforms/Passes.h"
@@ -70,12 +69,12 @@ namespace {
 // currently only eliminates the stores only if no other loads/uses (other
 // than dealloc) remain.
 //
-struct MemRefDataFlowOpt : public FunctionPass, InstWalker<MemRefDataFlowOpt> {
+struct MemRefDataFlowOpt : public FunctionPass {
   explicit MemRefDataFlowOpt() : FunctionPass(&MemRefDataFlowOpt::passID) {}
 
   PassResult runOnFunction(Function *f) override;
 
-  void visitInstruction(Instruction *opInst);
+  void forwardStoreToLoad(OpPointer<LoadOp> loadOp);
 
   // A list of memref's that are potentially dead / could be eliminated.
   SmallPtrSet<Value *, 4> memrefsToErase;
@@ -100,14 +99,9 @@ FunctionPass *mlir::createMemRefDataFlowOptPass() {
 
 // This is a straightforward implementation not optimized for speed. Optimize
 // this in the future if needed.
-void MemRefDataFlowOpt::visitInstruction(Instruction *opInst) {
+void MemRefDataFlowOpt::forwardStoreToLoad(OpPointer<LoadOp> loadOp) {
   Instruction *lastWriteStoreOp = nullptr;
-
-  auto loadOp = opInst->dyn_cast<LoadOp>();
-  if (!loadOp)
-    return;
-
-  Instruction *loadOpInst = opInst;
+  Instruction *loadOpInst = loadOp->getInstruction();
 
   // First pass over the use list to get minimum number of surrounding
   // loops common between the load op and the store op, with min taken across
@@ -235,7 +229,8 @@ PassResult MemRefDataFlowOpt::runOnFunction(Function *f) {
   memrefsToErase.clear();
 
   // Walk all load's and perform load/store forwarding.
-  walk(f);
+  f->walk<LoadOp>(
+      [&](OpPointer<LoadOp> loadOp) { forwardStoreToLoad(loadOp); });
 
   // Erase all load op's whose results were replaced with store fwd'ed ones.
   for (auto *loadOp : loadOpsToErase) {
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index ba3be5e95f4..4ca48a53485 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -142,10 +142,8 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) {
   // deleted and replaced by a prologue, a new steady-state loop and an
   // epilogue).
   forOps.clear();
-  f->walkPostOrder([&](Instruction *opInst) {
-    if (auto forOp = opInst->dyn_cast<AffineForOp>())
-      forOps.push_back(forOp);
-  });
+  f->walkPostOrder<AffineForOp>(
+      [&](OpPointer<AffineForOp> forOp) { forOps.push_back(forOp); });
   bool ret = false;
   for (auto forOp : forOps) {
     ret = ret | runOnAffineForOp(forOp);
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 5bf17989bef..95875adca6e 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -28,7 +28,6 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/InstVisitor.h"
 #include "mlir/IR/Instruction.h"
 #include "mlir/StandardOps/StandardOps.h"
 #include "llvm/ADT/DenseMap.h"
@@ -135,10 +134,8 @@ bool mlir::promoteIfSingleIteration(OpPointer<AffineForOp> forOp) {
 /// their body into the containing Block.
 void mlir::promoteSingleIterationLoops(Function *f) {
   // Gathers all innermost loops through a post order pruned walk.
-  f->walkPostOrder([](Instruction *inst) {
-    if (auto forOp = inst->dyn_cast<AffineForOp>())
-      promoteIfSingleIteration(forOp);
-  });
+  f->walkPostOrder<AffineForOp>(
+      [](OpPointer<AffineForOp> forOp) { promoteIfSingleIteration(forOp); });
 }
 
 /// Generates a 'for' inst with the specified lower and upper bounds while
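
PipelineDataTransfer and promoteSingleIterationLoops both gather loops with walkPostOrder so that inner loops are collected before the outer loops containing them; as the LoopUnroll comment notes, transforming an outer loop first may delete inner loops already gathered. A toy sketch of why a post-order walk yields that innermost-first ordering (assumed names, not MLIR's API):

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Inst {
  std::string name;
  std::vector<std::unique_ptr<Inst>> body;

  void walkPostOrder(const std::function<void(Inst *)> &fn) {
    for (auto &c : body)
      c->walkPostOrder(fn); // children first...
    fn(this);               // ...then the instruction itself
  }
};

int main() {
  Inst outer{"outer.for", {}};
  outer.body.push_back(std::make_unique<Inst>(Inst{"inner.for", {}}));

  // Prints inner.for before outer.for: a safe gathering order when the
  // callback (or a later pass over the gathered list) rewrites loops.
  outer.walkPostOrder([](Inst *i) { std::cout << i->name << "\n"; });
}
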

