diff options
Diffstat (limited to 'mlir/lib')
25 files changed, 163 insertions, 258 deletions
diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp index 8a0cb44f0cc..a6730f01199 100644 --- a/mlir/lib/Analysis/MemRefBoundCheck.cpp +++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp @@ -37,22 +37,18 @@ using namespace mlir; namespace { /// Checks for out of bound memef access subscripts.. -struct MemRefBoundCheck : public FunctionPass { - explicit MemRefBoundCheck() : FunctionPass(&MemRefBoundCheck::passID) {} - - PassResult runOnFunction(Function *f) override; - - constexpr static PassID passID = {}; +struct MemRefBoundCheck : public FunctionPass<MemRefBoundCheck> { + PassResult runOnFunction() override; }; } // end anonymous namespace -FunctionPass *mlir::createMemRefBoundCheckPass() { +FunctionPassBase *mlir::createMemRefBoundCheckPass() { return new MemRefBoundCheck(); } -PassResult MemRefBoundCheck::runOnFunction(Function *f) { - f->walk([](Instruction *opInst) { +PassResult MemRefBoundCheck::runOnFunction() { + getFunction().walk([](Instruction *opInst) { if (auto loadOp = opInst->dyn_cast<LoadOp>()) { boundCheckLoadOrStoreOp(loadOp); } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) { diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp index 93d4fde1fd9..33488f0c7a8 100644 --- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp +++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp @@ -37,19 +37,14 @@ namespace { // TODO(andydavis) Add common surrounding loop depth-wise dependence checks. /// Checks dependences between all pairs of memref accesses in a Function. -struct MemRefDependenceCheck : public FunctionPass { +struct MemRefDependenceCheck : public FunctionPass<MemRefDependenceCheck> { SmallVector<Instruction *, 4> loadsAndStores; - explicit MemRefDependenceCheck() - : FunctionPass(&MemRefDependenceCheck::passID) {} - - PassResult runOnFunction(Function *f) override; - - constexpr static PassID passID = {}; + PassResult runOnFunction() override; }; } // end anonymous namespace -FunctionPass *mlir::createMemRefDependenceCheckPass() { +FunctionPassBase *mlir::createMemRefDependenceCheckPass() { return new MemRefDependenceCheck(); } @@ -116,10 +111,10 @@ static void checkDependences(ArrayRef<Instruction *> loadsAndStores) { // Walks the Function 'f' adding load and store ops to 'loadsAndStores'. // Runs pair-wise dependence checks. -PassResult MemRefDependenceCheck::runOnFunction(Function *f) { +PassResult MemRefDependenceCheck::runOnFunction() { // Collect the loads and stores within the function. loadsAndStores.clear(); - f->walk([&](Instruction *inst) { + getFunction().walk([&](Instruction *inst) { if (inst->isa<LoadOp>() || inst->isa<StoreOp>()) loadsAndStores.push_back(inst); }); diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index c1fcacac15a..a17be9d176b 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -26,29 +26,26 @@ using namespace mlir; namespace { -struct PrintOpStatsPass : public ModulePass { - explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs()) - : ModulePass(&PrintOpStatsPass::passID), os(os) {} +struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> { + explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs()) : os(os) {} // Prints the resultant operation statistics post iterating over the module. - PassResult runOnModule(Module *m) override; + PassResult runOnModule() override; // Print summary of op stats. void printSummary(); - constexpr static PassID passID = {}; - private: llvm::StringMap<int64_t> opCount; llvm::raw_ostream &os; }; } // namespace -PassResult PrintOpStatsPass::runOnModule(Module *m) { +PassResult PrintOpStatsPass::runOnModule() { opCount.clear(); // Compute the operation statistics for each function in the module. - for (auto &fn : *m) + for (auto &fn : getModule()) fn.walk( [&](Instruction *inst) { ++opCount[inst->getName().getStringRef()]; }); printSummary(); diff --git a/mlir/lib/EDSC/LowerEDSCTestPass.cpp b/mlir/lib/EDSC/LowerEDSCTestPass.cpp index cdbe6e52ec0..2b6c38bf8c6 100644 --- a/mlir/lib/EDSC/LowerEDSCTestPass.cpp +++ b/mlir/lib/EDSC/LowerEDSCTestPass.cpp @@ -33,18 +33,15 @@ using namespace mlir; namespace { // Testing pass to lower EDSC. -struct LowerEDSCTestPass : public FunctionPass { - LowerEDSCTestPass() : FunctionPass(&LowerEDSCTestPass::passID) {} - PassResult runOnFunction(Function *f) override; - - constexpr static PassID passID = {}; +struct LowerEDSCTestPass : public FunctionPass<LowerEDSCTestPass> { + PassResult runOnFunction() override; }; } // end anonymous namespace #include "mlir/EDSC/reference-impl.inc" -PassResult LowerEDSCTestPass::runOnFunction(Function *f) { - f->walk([](Instruction *op) { +PassResult LowerEDSCTestPass::runOnFunction() { + getFunction().walk([](Instruction *op) { if (op->getName().getStringRef() == "print") { auto opName = op->getAttrOfType<StringAttr>("op"); if (!opName) { diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index 01278aad8af..1a3dd6ffff0 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -22,7 +22,7 @@ #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" -#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Target/LLVMIR.h" #include "mlir/Transforms/Passes.h" @@ -168,35 +168,20 @@ static inline Error make_string_error(const llvm::Twine &message) { // - CSE // - canonicalization // - affine lowering -static std::vector<std::unique_ptr<mlir::Pass>> -getDefaultPasses(const std::vector<const mlir::PassInfo *> &mlirPassInfoList) { - std::vector<std::unique_ptr<mlir::Pass>> passList; - passList.reserve(mlirPassInfoList.size() + 4); +static void +getDefaultPasses(PassManager &manager, + const std::vector<const mlir::PassInfo *> &mlirPassInfoList) { // Run each of the passes that were selected. for (const auto *passInfo : mlirPassInfoList) { - passList.emplace_back(passInfo->createPass()); + manager.addPass(passInfo->createPass()); } - // Append the extra passes for lowering to MLIR. - passList.emplace_back(mlir::createConstantFoldPass()); - passList.emplace_back(mlir::createCSEPass()); - passList.emplace_back(mlir::createCanonicalizerPass()); - passList.emplace_back(mlir::createLowerAffinePass()); - passList.emplace_back(mlir::createConvertToLLVMIRPass()); - return passList; -} -// Run the passes sequentially on the given module. -// Return `nullptr` immediately if any of the passes fails. -static bool runPasses(const std::vector<std::unique_ptr<mlir::Pass>> &passes, - Module *module) { - for (const auto &pass : passes) { - mlir::PassResult result = pass->runOnModule(module); - if (result == mlir::PassResult::Failure || module->verify()) { - llvm::errs() << "Pass failed\n"; - return true; - } - } - return false; + // Append the extra passes for lowering to MLIR. + manager.addPass(mlir::createConstantFoldPass()); + manager.addPass(mlir::createCSEPass()); + manager.addPass(mlir::createCanonicalizerPass()); + manager.addPass(mlir::createLowerAffinePass()); + manager.addPass(mlir::createConvertToLLVMIRPass()); } // Setup LLVM target triple from the current machine. @@ -295,7 +280,10 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create( if (!expectedJIT) return expectedJIT.takeError(); - if (runPasses(getDefaultPasses({}), m)) + // Construct and run the default MLIR pipeline. + PassManager manager; + getDefaultPasses(manager, {}); + if (manager.run(m)) return make_string_error("passes failed"); auto llvmModule = translateModuleToLLVMIR(*m); diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index 7421ebbeaaa..64ee5862ae7 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -1137,13 +1137,10 @@ static void ensureDistinctSuccessors(Module *m) { /// A pass converting MLIR Standard and Builtin operations into the LLVM IR /// dialect. -class LLVMLowering : public ModulePass, public DialectConversion { +class LLVMLowering : public ModulePass<LLVMLowering>, public DialectConversion { public: - LLVMLowering() : ModulePass(&passID) {} - - constexpr static PassID passID = {}; - - PassResult runOnModule(Module *m) override { + PassResult runOnModule() override { + Module *m = &getModule(); uniqueSuccessorsWithArguments(m); return DialectConversion::convert(m) ? failure() : success(); } @@ -1203,7 +1200,7 @@ private: llvm::Module *module; }; -ModulePass *mlir::createConvertToLLVMIRPass() { return new LLVMLowering; } +ModulePassBase *mlir::createConvertToLLVMIRPass() { return new LLVMLowering(); } static PassRegistration<LLVMLowering> pass("convert-to-llvmir", "Convert all functions to the LLVM IR dialect"); diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index b652f1f700b..c05c9d24aa4 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -29,24 +29,6 @@ using namespace mlir; /// single .o file. void Pass::anchor() {} -/// Out of line virtual method to ensure vtables and metadata are emitted to a -/// single .o file. -void ModulePass::anchor() {} - -/// Function passes walk a module and look at each function with their -/// corresponding hooks and terminates upon error encountered. -PassResult FunctionPass::runOnModule(Module *m) { - for (auto &fn : *m) { - // All function passes ignore external functions. - if (fn.isExternal()) - continue; - - if (runOnFunction(&fn)) - return failure(); - } - return success(); -} - /// Forwarding function to execute this pass. PassResult FunctionPassBase::run(Function *fn) { /// Initialize the pass state. diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index fee9d5a3828..24b53220613 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -80,11 +80,7 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Instruction *> { namespace { /// Simple common sub-expression elimination. -struct CSE : public FunctionPass { - CSE() : FunctionPass(&CSE::passID) {} - - constexpr static PassID passID = {}; - +struct CSE : public FunctionPass<CSE> { /// Shared implementation of operation elimination and scoped map definitions. using AllocatorTy = llvm::RecyclingAllocator< llvm::BumpPtrAllocator, @@ -115,7 +111,7 @@ struct CSE : public FunctionPass { void simplifyBlock(DominanceInfo &domInfo, Block *bb); void simplifyBlockList(DominanceInfo &domInfo, BlockList &blockList); - PassResult runOnFunction(Function *f) override; + PassResult runOnFunction() override; private: /// A scoped hash table of defining operations within a function. @@ -220,9 +216,9 @@ void CSE::simplifyBlockList(DominanceInfo &domInfo, BlockList &blockList) { } } -PassResult CSE::runOnFunction(Function *f) { - DominanceInfo domInfo(f); - simplifyBlockList(domInfo, f->getBlockList()); +PassResult CSE::runOnFunction() { + DominanceInfo domInfo(&getFunction()); + simplifyBlockList(domInfo, getFunction().getBlockList()); /// Erase any operations that were marked as dead during simplification. for (auto *op : opsToErase) @@ -232,7 +228,7 @@ PassResult CSE::runOnFunction(Function *f) { return success(); } -FunctionPass *mlir::createCSEPass() { return new CSE(); } +FunctionPassBase *mlir::createCSEPass() { return new CSE(); } static PassRegistration<CSE> pass("cse", "Eliminate common sub-expressions in functions"); diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index ac77e201acf..764f055a673 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -33,30 +33,30 @@ using namespace mlir; namespace { /// Canonicalize operations in functions. -struct Canonicalizer : public FunctionPass { - Canonicalizer() : FunctionPass(&Canonicalizer::passID) {} - PassResult runOnFunction(Function *fn) override; - - constexpr static PassID passID = {}; +struct Canonicalizer : public FunctionPass<Canonicalizer> { + PassResult runOnFunction() override; }; } // end anonymous namespace -PassResult Canonicalizer::runOnFunction(Function *fn) { - auto *context = fn->getContext(); +PassResult Canonicalizer::runOnFunction() { OwningRewritePatternList patterns; + auto &func = getFunction(); // TODO: Instead of adding all known patterns from the whole system lazily add // and cache the canonicalization patterns for ops we see in practice when // building the worklist. For now, we just grab everything. - for (auto *op : fn->getContext()->getRegisteredOperations()) + auto *context = func.getContext(); + for (auto *op : context->getRegisteredOperations()) op->getCanonicalizationPatterns(patterns, context); - applyPatternsGreedily(fn, std::move(patterns)); + applyPatternsGreedily(&func, std::move(patterns)); return success(); } /// Create a Canonicalizer pass. -FunctionPass *mlir::createCanonicalizerPass() { return new Canonicalizer(); } +FunctionPassBase *mlir::createCanonicalizerPass() { + return new Canonicalizer(); +} static PassRegistration<Canonicalizer> pass("canonicalize", "Canonicalize operations"); diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index 4817baaa23e..ed35c03755f 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -26,18 +26,14 @@ using namespace mlir; namespace { /// Simple constant folding pass. -struct ConstantFold : public FunctionPass { - ConstantFold() : FunctionPass(&ConstantFold::passID) {} - +struct ConstantFold : public FunctionPass<ConstantFold> { // All constants in the function post folding. SmallVector<Value *, 8> existingConstants; // Operations that were folded and that need to be erased. std::vector<Instruction *> opInstsToErase; void foldInstruction(Instruction *op); - PassResult runOnFunction(Function *f) override; - - constexpr static PassID passID = {}; + PassResult runOnFunction() override; }; } // end anonymous namespace @@ -96,11 +92,11 @@ void ConstantFold::foldInstruction(Instruction *op) { // For now, we do a simple top-down pass over a function folding constants. We // don't handle conditional control flow, block arguments, folding // conditional branches, or anything else fancy. -PassResult ConstantFold::runOnFunction(Function *f) { +PassResult ConstantFold::runOnFunction() { existingConstants.clear(); opInstsToErase.clear(); - f->walk([&](Instruction *inst) { foldInstruction(inst); }); + getFunction().walk([&](Instruction *inst) { foldInstruction(inst); }); // At this point, these operations are dead, remove them. // TODO: This is assuming that all constant foldable operations have no @@ -122,7 +118,7 @@ PassResult ConstantFold::runOnFunction(Function *f) { } /// Creates a constant folding pass. -FunctionPass *mlir::createConstantFoldPass() { return new ConstantFold(); } +FunctionPassBase *mlir::createConstantFoldPass() { return new ConstantFold(); } static PassRegistration<ConstantFold> pass("constant-fold", "Constant fold operations in functions"); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 4fb6f34ed53..82ba07acb5f 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -74,18 +74,17 @@ namespace { /// memory capacity provided. // TODO(bondhugula): We currently can't generate DMAs correctly when stores are // strided. Check for strided stores. -struct DmaGeneration : public FunctionPass { +struct DmaGeneration : public FunctionPass<DmaGeneration> { explicit DmaGeneration(unsigned slowMemorySpace = 0, unsigned fastMemorySpace = clFastMemorySpace, int minDmaTransferSize = 1024, uint64_t fastMemCapacityBytes = clFastMemoryCapacity * 1024) - : FunctionPass(&DmaGeneration::passID), slowMemorySpace(slowMemorySpace), - fastMemorySpace(fastMemorySpace), + : slowMemorySpace(slowMemorySpace), fastMemorySpace(fastMemorySpace), minDmaTransferSize(minDmaTransferSize), fastMemCapacityBytes(fastMemCapacityBytes) {} - PassResult runOnFunction(Function *f) override; + PassResult runOnFunction() override; bool runOnBlock(Block *block); uint64_t runOnBlock(Block::iterator begin, Block::iterator end); @@ -115,8 +114,6 @@ struct DmaGeneration : public FunctionPass { // Constant zero index to avoid too many duplicates. Value *zeroIndex = nullptr; - - constexpr static PassID passID = {}; }; } // end anonymous namespace @@ -125,10 +122,10 @@ struct DmaGeneration : public FunctionPass { /// buffers in 'fastMemorySpace', and replaces memory operations to the former /// by the latter. Only load op's handled for now. /// TODO(bondhugula): extend this to store op's. -FunctionPass *mlir::createDmaGenerationPass(unsigned slowMemorySpace, - unsigned fastMemorySpace, - int minDmaTransferSize, - uint64_t fastMemCapacityBytes) { +FunctionPassBase *mlir::createDmaGenerationPass(unsigned slowMemorySpace, + unsigned fastMemorySpace, + int minDmaTransferSize, + uint64_t fastMemCapacityBytes) { return new DmaGeneration(slowMemorySpace, fastMemorySpace, minDmaTransferSize, fastMemCapacityBytes); } @@ -757,7 +754,8 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { return totalDmaBuffersSizeInBytes; } -PassResult DmaGeneration::runOnFunction(Function *f) { +PassResult DmaGeneration::runOnFunction() { + Function *f = &getFunction(); FuncBuilder topBuilder(f); zeroIndex = topBuilder.create<ConstantIndexOp>(f->getLoc(), 0); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 0f4e45c372a..1528e394506 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -86,14 +86,12 @@ namespace { // TODO(andydavis) Extend this pass to check for fusion preventing dependences, // and add support for more general loop fusion algorithms. -struct LoopFusion : public FunctionPass { +struct LoopFusion : public FunctionPass<LoopFusion> { LoopFusion(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0) - : FunctionPass(&LoopFusion::passID), - localBufSizeThreshold(localBufSizeThreshold), + : localBufSizeThreshold(localBufSizeThreshold), fastMemorySpace(fastMemorySpace) {} - PassResult runOnFunction(Function *f) override; - constexpr static PassID passID = {}; + PassResult runOnFunction() override; // Any local buffers smaller than this size (in bytes) will be created in // `fastMemorySpace` if provided. @@ -107,8 +105,8 @@ struct LoopFusion : public FunctionPass { } // end anonymous namespace -FunctionPass *mlir::createLoopFusionPass(unsigned fastMemorySpace, - uint64_t localBufSizeThreshold) { +FunctionPassBase *mlir::createLoopFusionPass(unsigned fastMemorySpace, + uint64_t localBufSizeThreshold) { return new LoopFusion(fastMemorySpace, localBufSizeThreshold); } @@ -1802,7 +1800,7 @@ public: } // end anonymous namespace -PassResult LoopFusion::runOnFunction(Function *f) { +PassResult LoopFusion::runOnFunction() { // Override if a command line argument was provided. if (clFusionFastMemorySpace.getNumOccurrences() > 0) { fastMemorySpace = clFusionFastMemorySpace.getValue(); @@ -1814,7 +1812,7 @@ PassResult LoopFusion::runOnFunction(Function *f) { } MemRefDependenceGraph g; - if (g.init(f)) + if (g.init(&getFunction())) GreedyFusion(&g).run(localBufSizeThreshold, fastMemorySpace); return success(); } diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 240b2b6d9b6..db0e8d51ad8 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -47,12 +47,10 @@ static llvm::cl::list<unsigned> clTileSizes( namespace { /// A pass to perform loop tiling on all suitable loop nests of a Function. -struct LoopTiling : public FunctionPass { - LoopTiling() : FunctionPass(&LoopTiling::passID) {} - PassResult runOnFunction(Function *f) override; +struct LoopTiling : public FunctionPass<LoopTiling> { + PassResult runOnFunction() override; constexpr static unsigned kDefaultTileSize = 4; - constexpr static PassID passID = {}; }; } // end anonymous namespace @@ -65,7 +63,7 @@ static llvm::cl::opt<unsigned> /// Creates a pass to perform loop tiling on all suitable loop nests of an /// Function. -FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); } +FunctionPassBase *mlir::createLoopTilingPass() { return new LoopTiling(); } // Move the loop body of AffineForOp 'src' from 'src' into the specified // location in destination's body. @@ -255,9 +253,9 @@ getTileableBands(Function *f, getMaximalPerfectLoopNest(forOp); } -PassResult LoopTiling::runOnFunction(Function *f) { +PassResult LoopTiling::runOnFunction() { std::vector<SmallVector<OpPointer<AffineForOp>, 6>> bands; - getTileableBands(f, &bands); + getTileableBands(&getFunction(), &bands); for (auto &band : bands) { // Set up tile sizes; fill missing tile sizes at the end with default tile diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 3b4a0517f0d..231dba65720 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -65,7 +65,7 @@ namespace { /// full unroll threshold was specified, in which case, fully unrolls all loops /// with trip count less than the specified threshold. The latter is for testing /// purposes, especially for testing outer loop unrolling. -struct LoopUnroll : public FunctionPass { +struct LoopUnroll : public FunctionPass<LoopUnroll> { const Optional<unsigned> unrollFactor; const Optional<bool> unrollFull; // Callback to obtain unroll factors; if this has a callable target, takes @@ -76,21 +76,19 @@ struct LoopUnroll : public FunctionPass { Optional<bool> unrollFull = None, const std::function<unsigned(ConstOpPointer<AffineForOp>)> &getUnrollFactor = nullptr) - : FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor), - unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {} + : unrollFactor(unrollFactor), unrollFull(unrollFull), + getUnrollFactor(getUnrollFactor) {} - PassResult runOnFunction(Function *f) override; + PassResult runOnFunction() override; /// Unroll this for inst. Returns false if nothing was done. bool runOnAffineForOp(OpPointer<AffineForOp> forOp); static const unsigned kDefaultUnrollFactor = 4; - - constexpr static PassID passID = {}; }; } // end anonymous namespace -PassResult LoopUnroll::runOnFunction(Function *f) { +PassResult LoopUnroll::runOnFunction() { // Gathers all innermost loops through a post order pruned walk. struct InnermostLoopGatherer { // Store innermost loops as we walk. @@ -132,7 +130,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) { // Gathers all loops with trip count <= minTripCount. Do a post order walk // so that loops are gathered from innermost to outermost (or else unrolling // an outer one may delete gathered inner ones). - f->walkPostOrder<AffineForOp>([&](OpPointer<AffineForOp> forOp) { + getFunction().walkPostOrder<AffineForOp>([&](OpPointer<AffineForOp> forOp) { Optional<uint64_t> tripCount = getConstantTripCount(forOp); if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold) loops.push_back(forOp); @@ -146,9 +144,10 @@ PassResult LoopUnroll::runOnFunction(Function *f) { ? clUnrollNumRepetitions : 1; // If the call back is provided, we will recurse until no loops are found. + Function *func = &getFunction(); for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { InnermostLoopGatherer ilg; - ilg.walkPostOrder(f); + ilg.walkPostOrder(func); auto &loops = ilg.loops; if (loops.empty()) break; @@ -184,7 +183,7 @@ bool LoopUnroll::runOnAffineForOp(OpPointer<AffineForOp> forOp) { return loopUnrollByFactor(forOp, kDefaultUnrollFactor); } -FunctionPass *mlir::createLoopUnrollPass( +FunctionPassBase *mlir::createLoopUnrollPass( int unrollFactor, int unrollFull, const std::function<unsigned(ConstOpPointer<AffineForOp>)> &getUnrollFactor) { diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 87e2770aa41..e950d117ddc 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -71,34 +71,30 @@ static llvm::cl::opt<unsigned> namespace { /// Loop unroll jam pass. Currently, this just unroll jams the first /// outer loop in a Function. -struct LoopUnrollAndJam : public FunctionPass { +struct LoopUnrollAndJam : public FunctionPass<LoopUnrollAndJam> { Optional<unsigned> unrollJamFactor; static const unsigned kDefaultUnrollJamFactor = 4; explicit LoopUnrollAndJam(Optional<unsigned> unrollJamFactor = None) - : FunctionPass(&LoopUnrollAndJam::passID), - unrollJamFactor(unrollJamFactor) {} + : unrollJamFactor(unrollJamFactor) {} - PassResult runOnFunction(Function *f) override; + PassResult runOnFunction() override; bool runOnAffineForOp(OpPointer<AffineForOp> forOp); - - constexpr static PassID passID = {}; }; } // end anonymous namespace -FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { +FunctionPassBase *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { return new LoopUnrollAndJam( unrollJamFactor == -1 ? None : Optional<unsigned>(unrollJamFactor)); } -PassResult LoopUnrollAndJam::runOnFunction(Function *f) { +PassResult LoopUnrollAndJam::runOnFunction() { // Currently, just the outermost loop from the first loop nest is // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on - // any for Inst. - auto &entryBlock = f->front(); - if (!entryBlock.empty()) - if (auto forOp = entryBlock.front().dyn_cast<AffineForOp>()) - runOnAffineForOp(forOp); + // any for operation. + auto &entryBlock = getFunction().front(); + if (auto forOp = entryBlock.front().dyn_cast<AffineForOp>()) + runOnAffineForOp(forOp); return success(); } diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 83620516994..aecd4314d42 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -242,16 +242,12 @@ Optional<SmallVector<Value *, 8>> static expandAffineMap( } namespace { -class LowerAffinePass : public FunctionPass { -public: - LowerAffinePass() : FunctionPass(&passID) {} - PassResult runOnFunction(Function *function) override; +struct LowerAffinePass : public FunctionPass<LowerAffinePass> { + PassResult runOnFunction() override; bool lowerAffineFor(OpPointer<AffineForOp> forOp); bool lowerAffineIf(AffineIfOp *ifOp); bool lowerAffineApply(AffineApplyOp *op); - - constexpr static PassID passID = {}; }; } // end anonymous namespace @@ -608,12 +604,12 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp *op) { // construction. When an Value is used, it gets replaced with the // corresponding Value that has been defined previously. The value flow // starts with function arguments converted to basic block arguments. -PassResult LowerAffinePass::runOnFunction(Function *function) { +PassResult LowerAffinePass::runOnFunction() { SmallVector<Instruction *, 8> instsToRewrite; // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. // We do this as a prepass to avoid invalidating the walker with our rewrite. - function->walk([&](Instruction *inst) { + getFunction().walk([&](Instruction *inst) { if (inst->isa<AffineApplyOp>() || inst->isa<AffineForOp>() || inst->isa<AffineIfOp>()) instsToRewrite.push_back(inst); @@ -638,7 +634,9 @@ PassResult LowerAffinePass::runOnFunction(Function *function) { /// Lowers If and For instructions within a function into their lower level CFG /// equivalent blocks. -FunctionPass *mlir::createLowerAffinePass() { return new LowerAffinePass(); } +FunctionPassBase *mlir::createLowerAffinePass() { + return new LowerAffinePass(); +} static PassRegistration<LowerAffinePass> pass("lower-affine", diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 61f75ae76e6..ddeb524f5ab 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -424,25 +424,22 @@ public: } }; -struct LowerVectorTransfersPass : public FunctionPass { - LowerVectorTransfersPass() - : FunctionPass(&LowerVectorTransfersPass::passID) {} - - PassResult runOnFunction(Function *fn) override { +struct LowerVectorTransfersPass + : public FunctionPass<LowerVectorTransfersPass> { + PassResult runOnFunction() { + Function *f = &getFunction(); applyMLPatternsGreedily<VectorTransferExpander<VectorTransferReadOp>, - VectorTransferExpander<VectorTransferWriteOp>>(fn); + VectorTransferExpander<VectorTransferWriteOp>>(f); return success(); } // Thread-safe RAII context with local scope. BumpPtrAllocator freed on exit. edsc::ScopedEDSCContext raiiContext; - - constexpr static PassID passID = {}; }; } // end anonymous namespace -FunctionPass *mlir::createLowerVectorTransfersPass() { +FunctionPassBase *mlir::createLowerVectorTransfersPass() { return new LowerVectorTransfersPass(); } diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 6177ca1233b..7b45af011ab 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -196,12 +196,8 @@ struct MaterializationState { DenseMap<const Value *, Value *> *substitutionsMap; }; -struct MaterializeVectorsPass : public FunctionPass { - MaterializeVectorsPass() : FunctionPass(&MaterializeVectorsPass::passID) {} - - PassResult runOnFunction(Function *f) override; - - constexpr static PassID passID = {}; +struct MaterializeVectorsPass : public FunctionPass<MaterializeVectorsPass> { + PassResult runOnFunction() override; }; } // end anonymous namespace @@ -733,11 +729,12 @@ static bool materialize(Function *f, return false; } -PassResult MaterializeVectorsPass::runOnFunction(Function *f) { +PassResult MaterializeVectorsPass::runOnFunction() { // Thread-safe RAII local context, BumpPtrAllocator freed on exit. NestedPatternContext mlContext; // TODO(ntv): Check to see if this supports arbitrary top-level code. + Function *f = &getFunction(); if (f->getBlocks().size() != 1) return success(); @@ -771,7 +768,7 @@ PassResult MaterializeVectorsPass::runOnFunction(Function *f) { return fail ? PassResult::Failure : PassResult::Success; } -FunctionPass *mlir::createMaterializeVectorsPass() { +FunctionPassBase *mlir::createMaterializeVectorsPass() { return new MaterializeVectorsPass(); } diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index 0ba06fecae0..067bfa4c94c 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -69,10 +69,8 @@ namespace { // currently only eliminates the stores only if no other loads/uses (other // than dealloc) remain. // -struct MemRefDataFlowOpt : public FunctionPass { - explicit MemRefDataFlowOpt() : FunctionPass(&MemRefDataFlowOpt::passID) {} - - PassResult runOnFunction(Function *f) override; +struct MemRefDataFlowOpt : public FunctionPass<MemRefDataFlowOpt> { + PassResult runOnFunction() override; void forwardStoreToLoad(OpPointer<LoadOp> loadOp); @@ -83,15 +81,13 @@ struct MemRefDataFlowOpt : public FunctionPass { DominanceInfo *domInfo = nullptr; PostDominanceInfo *postDomInfo = nullptr; - - constexpr static PassID passID = {}; }; } // end anonymous namespace /// Creates a pass to perform optimizations relying on memref dataflow such as /// store to load forwarding, elimination of dead stores, and dead allocs. -FunctionPass *mlir::createMemRefDataFlowOptPass() { +FunctionPassBase *mlir::createMemRefDataFlowOptPass() { return new MemRefDataFlowOpt(); } @@ -213,22 +209,22 @@ void MemRefDataFlowOpt::forwardStoreToLoad(OpPointer<LoadOp> loadOp) { loadOpsToErase.push_back(loadOpInst); } -PassResult MemRefDataFlowOpt::runOnFunction(Function *f) { +PassResult MemRefDataFlowOpt::runOnFunction() { // Only supports single block functions at the moment. - if (f->getBlocks().size() != 1) + Function &f = getFunction(); + if (f.getBlocks().size() != 1) return success(); - DominanceInfo theDomInfo(f); + DominanceInfo theDomInfo(&f); domInfo = &theDomInfo; - PostDominanceInfo thePostDomInfo(f); + PostDominanceInfo thePostDomInfo(&f); postDomInfo = &thePostDomInfo; loadOpsToErase.clear(); memrefsToErase.clear(); // Walk all load's and perform load/store forwarding. - f->walk<LoadOp>( - [&](OpPointer<LoadOp> loadOp) { forwardStoreToLoad(loadOp); }); + f.walk<LoadOp>([&](OpPointer<LoadOp> loadOp) { forwardStoreToLoad(loadOp); }); // Erase all load op's whose results were replaced with store fwd'ed ones. for (auto *loadOp : loadOpsToErase) { diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index f41f56efd8f..42e1446211b 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -38,21 +38,18 @@ using namespace mlir; namespace { -struct PipelineDataTransfer : public FunctionPass { - PipelineDataTransfer() : FunctionPass(&PipelineDataTransfer::passID) {} - PassResult runOnFunction(Function *f) override; +struct PipelineDataTransfer : public FunctionPass<PipelineDataTransfer> { + PassResult runOnFunction() override; PassResult runOnAffineForOp(OpPointer<AffineForOp> forOp); std::vector<OpPointer<AffineForOp>> forOps; - - constexpr static PassID passID = {}; }; } // end anonymous namespace /// Creates a pass to pipeline explicit movement of data across levels of the /// memory hierarchy. -FunctionPass *mlir::createPipelineDataTransferPass() { +FunctionPassBase *mlir::createPipelineDataTransferPass() { return new PipelineDataTransfer(); } @@ -142,14 +139,14 @@ static bool doubleBuffer(Value *oldMemRef, OpPointer<AffineForOp> forOp) { } /// Returns success if the IR is in a valid state. -PassResult PipelineDataTransfer::runOnFunction(Function *f) { +PassResult PipelineDataTransfer::runOnFunction() { // Do a post order walk so that inner loop DMAs are processed first. This is // necessary since 'for' instructions nested within would otherwise become // invalid (erased) when the outer loop is pipelined (the pipelined one gets // deleted and replaced by a prologue, a new steady-state loop and an // epilogue). forOps.clear(); - f->walkPostOrder<AffineForOp>( + getFunction().walkPostOrder<AffineForOp>( [&](OpPointer<AffineForOp> forOp) { forOps.push_back(forOp); }); bool ret = false; for (auto forOp : forOps) { diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index d0fdcb5527f..4c0fed5b648 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -36,18 +36,14 @@ namespace { /// the Function. This is mainly to test the simplifyAffineExpr method. /// TODO(someone): This should just be defined as a canonicalization pattern /// on AffineMap and driven from the existing canonicalization pass. -struct SimplifyAffineStructures : public FunctionPass { - explicit SimplifyAffineStructures() - : FunctionPass(&SimplifyAffineStructures::passID) {} - - PassResult runOnFunction(Function *f) override; - - constexpr static PassID passID = {}; +struct SimplifyAffineStructures + : public FunctionPass<SimplifyAffineStructures> { + PassResult runOnFunction() override; }; } // end anonymous namespace -FunctionPass *mlir::createSimplifyAffineStructuresPass() { +FunctionPassBase *mlir::createSimplifyAffineStructuresPass() { return new SimplifyAffineStructures(); } @@ -61,8 +57,8 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) { return set; } -PassResult SimplifyAffineStructures::runOnFunction(Function *f) { - f->walk([&](Instruction *opInst) { +PassResult SimplifyAffineStructures::runOnFunction() { + getFunction().walk([&](Instruction *opInst) { for (auto attr : opInst->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) { MutableAffineMap mMap(mapAttr.getValue()); diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index fc2b0eb0a95..0f1ba02174b 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -23,26 +23,25 @@ using namespace mlir; namespace { -struct StripDebugInfo : public FunctionPass { - StripDebugInfo() : FunctionPass(&StripDebugInfo::passID) {} - - PassResult runOnFunction(Function *f) override; - - constexpr static PassID passID = {}; +struct StripDebugInfo : public FunctionPass<StripDebugInfo> { + PassResult runOnFunction() override; }; } // end anonymous namespace -PassResult StripDebugInfo::runOnFunction(Function *f) { - UnknownLoc unknownLoc = UnknownLoc::get(f->getContext()); +PassResult StripDebugInfo::runOnFunction() { + Function &func = getFunction(); + UnknownLoc unknownLoc = UnknownLoc::get(func.getContext()); // Strip the debug info from the function and its instructions. - f->setLoc(unknownLoc); - f->walk([&](Instruction *inst) { inst->setLoc(unknownLoc); }); + func.setLoc(unknownLoc); + func.walk([&](Instruction *inst) { inst->setLoc(unknownLoc); }); return success(); } /// Creates a pass to strip debug information from a function. -FunctionPass *mlir::createStripDebugInfoPass() { return new StripDebugInfo(); } +FunctionPassBase *mlir::createStripDebugInfoPass() { + return new StripDebugInfo(); +} static PassRegistration<StripDebugInfo> pass("strip-debuginfo", "Strip debug info from functions and instructions"); diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index 2363c5638ee..60e58c42e6b 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -83,20 +83,17 @@ static llvm::cl::opt<bool> clTestNormalizeMaps( namespace { -struct VectorizerTestPass : public FunctionPass { +struct VectorizerTestPass : public FunctionPass<VectorizerTestPass> { static constexpr auto kTestAffineMapOpName = "test_affine_map"; static constexpr auto kTestAffineMapAttrName = "affine_map"; - VectorizerTestPass() : FunctionPass(&VectorizerTestPass::passID) {} - PassResult runOnFunction(Function *f) override; + PassResult runOnFunction() override; void testVectorShapeRatio(Function *f); void testForwardSlicing(Function *f); void testBackwardSlicing(Function *f); void testSlicing(Function *f); void testComposeMaps(Function *f); void testNormalizeMaps(Function *f); - - constexpr static PassID passID = {}; }; } // end anonymous namespace @@ -263,11 +260,12 @@ void VectorizerTestPass::testNormalizeMaps(Function *f) { } } -PassResult VectorizerTestPass::runOnFunction(Function *f) { +PassResult VectorizerTestPass::runOnFunction() { // Thread-safe RAII local context, BumpPtrAllocator freed on exit. NestedPatternContext mlContext; // Only support single block functions at this point. + Function *f = &getFunction(); if (f->getBlocks().size() != 1) return success(); @@ -292,7 +290,7 @@ PassResult VectorizerTestPass::runOnFunction(Function *f) { return PassResult::Success; } -FunctionPass *mlir::createVectorizerTestPass() { +FunctionPassBase *mlir::createVectorizerTestPass() { return new VectorizerTestPass(); } diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 5722b9d17da..8a378a29c84 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -651,12 +651,8 @@ static std::vector<NestedPattern> makePatterns() { namespace { -struct Vectorize : public FunctionPass { - Vectorize() : FunctionPass(&Vectorize::passID) {} - - PassResult runOnFunction(Function *f) override; - - constexpr static PassID passID = {}; +struct Vectorize : public FunctionPass<Vectorize> { + PassResult runOnFunction() override; }; } // end anonymous namespace @@ -1264,10 +1260,11 @@ static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { /// Applies vectorization to the current Function by searching over a bunch of /// predetermined patterns. -PassResult Vectorize::runOnFunction(Function *f) { +PassResult Vectorize::runOnFunction() { // Thread-safe RAII local context, BumpPtrAllocator freed on exit. NestedPatternContext mlContext; + Function *f = &getFunction(); for (auto &pat : makePatterns()) { LLVM_DEBUG(dbgs() << "\n******************************************"); LLVM_DEBUG(dbgs() << "\n******************************************"); @@ -1301,7 +1298,7 @@ PassResult Vectorize::runOnFunction(Function *f) { return PassResult::Success; } -FunctionPass *mlir::createVectorizePass() { return new Vectorize(); } +FunctionPassBase *mlir::createVectorizePass() { return new Vectorize(); } static PassRegistration<Vectorize> pass("vectorize", diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp index 14e21770e25..30fae94139f 100644 --- a/mlir/lib/Transforms/ViewFunctionGraph.cpp +++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp @@ -73,18 +73,15 @@ void mlir::Function::viewGraph() const { } namespace { -struct PrintCFGPass : public FunctionPass { +struct PrintCFGPass : public FunctionPass<PrintCFGPass> { PrintCFGPass(llvm::raw_ostream &os = llvm::errs(), bool shortNames = false, const llvm::Twine &title = "") - : FunctionPass(&PrintCFGPass::passID), os(os), shortNames(shortNames), - title(title) {} - PassResult runOnFunction(Function *function) override { - mlir::writeGraph(os, function, shortNames, title); + : os(os), shortNames(shortNames), title(title) {} + PassResult runOnFunction() { + mlir::writeGraph(os, &getFunction(), shortNames, title); return success(); } - constexpr static PassID passID = {}; - private: llvm::raw_ostream &os; bool shortNames; @@ -92,9 +89,9 @@ private: }; } // namespace -FunctionPass *mlir::createPrintCFGGraphPass(llvm::raw_ostream &os, - bool shortNames, - const llvm::Twine &title) { +FunctionPassBase *mlir::createPrintCFGGraphPass(llvm::raw_ostream &os, + bool shortNames, + const llvm::Twine &title) { return new PrintCFGPass(os, shortNames, title); } |

